diff --git a/LICENSE b/LICENSE index 9315c4efb68..966a609b61e 100644 --- a/LICENSE +++ b/LICENSE @@ -32,6 +32,10 @@ All contributions by Cruise LLC: Copyright (c) 2022 Cruise LLC. All rights reserved. +All contributions by Tri Dao: +Copyright (c) 2024 Tri Dao. +All rights reserved. + All contributions by Arm: Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 442ce7bbe89..9b526e0f2b8 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -168,9 +168,28 @@ file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu") file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu") file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp") -# flash_attention sources +# flash_attention hip sources file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip") -file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip") +# if USE_FLASH_ATTENTION is set, ensure CK instances get generated +if(USE_FLASH_ATTENTION) + if(DEFINED ENV{USE_CK_FLASH_ATTENTION}) + set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION}) + if(USE_CK_FLASH_ATTENTION STREQUAL "1") + if(DEFINED ENV{PYTORCH_ROCM_ARCH}) + list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS) + if(NUM_ARCHS GREATER 1) + message(WARNING "Building CK for multiple archs can increase build time considerably! + Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for") + endif() + endif() + message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled") + file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") + list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) + endif() + endif() + file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip") + file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip") +endif() #Mem_eff attention sources file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu") @@ -185,6 +204,7 @@ if(USE_FLASH_ATTENTION) list(APPEND ATen_ATTENTION_KERNEL_SRCS ${flash_attention_cuda_kernels_cu}) list(APPEND native_transformers_hip_hip ${flash_attention_hip_hip}) + list(APPEND native_transformers_hip_hip ${flash_attention_hip_aot_hip}) list(APPEND native_transformers_src_hip_hip ${flash_attention_src_hip_hip}) endif() @@ -325,6 +345,9 @@ if(USE_ROCM) # Next two lines are needed because TunableOp uses third-party/fmt list(APPEND ATen_HIP_INCLUDE $) list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only) +if(USE_FLASH_ATTENTION) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck) +endif() list(APPEND ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 3eb7c937538..ee9c762fdb9 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -343,6 +343,40 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) { #endif } +at::ROCmFABackend Context::getROCmFAPreferredBackend() const { + return rocm_fa_preferred_backend; +} + +void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) { + + // TODO: add plumbing for hasCK for validity checking + TORCH_CHECK((b != at::ROCmFABackend::Ck) || hasROCM(), + "Cannot set preferred flash attention backend to Ck if PyTorch has not been compiled for ROCm."); +#ifdef USE_ROCM + 
if(b == at::ROCmFABackend::Ck) { + static const bool ck_unsupported = []() { + static const std::vector archs = { + "gfx90a", "gfx942" + }; + for (auto index: c10::irange(getNumGPUs())) { + if (!detail::getCUDAHooks().isGPUArch(index, archs)) { + TORCH_WARN_ONCE( + "Attempting to use CK on an unsupported architecture! Cannot set backend to CK"); + return true; + } + } + return false; + }(); + if(!ck_unsupported) rocm_fa_preferred_backend = b; + } + else { + rocm_fa_preferred_backend = b; + } +#endif + rocm_fa_preferred_backend = b; +} + + bool Context::allowFP16ReductionCuBLAS() const { return allow_fp16_reduction_cublas; } diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index c5443c56a9c..87f53c5f197 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -239,6 +240,9 @@ class TORCH_API Context { at::BlasBackend blasPreferredBackend(); void setBlasPreferredBackend(at::BlasBackend); + at::ROCmFABackend getROCmFAPreferredBackend() const; + void setROCmFAPreferredBackend(at::ROCmFABackend); + // Note [Enabling Deterministic Operations] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Operations in PyTorch that normally act nondeterministically, but have an @@ -428,6 +432,10 @@ class TORCH_API Context { #endif ? at::BlasBackend::Cublaslt : at::BlasBackend::Cublas; + at::ROCmFABackend rocm_fa_preferred_backend = + c10::utils::check_env("TORCH_ROCM_FA_PREFER_CK") == true + ? at::ROCmFABackend::Ck + : at::ROCmFABackend::Default; #ifdef C10_MOBILE bool release_original_weights = true; #else diff --git a/aten/src/ATen/ROCmFABackend.h b/aten/src/ATen/ROCmFABackend.h new file mode 100644 index 00000000000..6e2844cc8be --- /dev/null +++ b/aten/src/ATen/ROCmFABackend.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +#include +#include + +namespace at { + +enum class ROCmFABackend : int8_t { Default, AOTriton, Ck }; + +inline std::string ROCmFABackendToString(at::ROCmFABackend backend) { + switch (backend) { + case ROCmFABackend::Default: + return "at::ROCmFABackend::Default"; + case ROCmFABackend::AOTriton: + return "at::ROCmFABackend::AOTriton"; + case ROCmFABackend::Ck: + return "at::ROCmFABackend::Ck"; + default: + TORCH_CHECK(false, "Unknown ROCm flash attention backend") + } +} + +inline std::ostream& operator<<( + std::ostream& stream, + at::ROCmFABackend backend) { + return stream << ROCmFABackendToString(backend); +} + +} // namespace at diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 80eb89600da..c83889cd4be 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -28,7 +28,7 @@ #if USE_ROCM #if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION) #include -#define USE_AOTRITON 1 +#define USE_ROCM_ATTENTION 1 #endif #endif @@ -219,15 +219,21 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug using sm80 = SMVersion<8, 0>; using sm90 = SMVersion<9, 0>; #if USE_ROCM -#if USE_AOTRITON - auto stream = at::cuda::getCurrentCUDAStream().stream(); - if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - if (debug) { - TORCH_WARN( - "Flash attention was not compiled for current AMD GPU architecture. 
Attempting to run on architecture ", dprops->gcnArchName); - } - return false; +#if USE_ROCM_ATTENTION + if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { + // User explicitly set CK as the flash attention backend. Return true for now + // TODO: Flesh out sanity checks + return true; + } else { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + if (debug) { + TORCH_WARN( + "Flash attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName); + } + return false; + } } #else return false; @@ -254,7 +260,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) using sm50 = SMVersion<5, 0>; using sm90 = SMVersion<9, 0>; #if USE_ROCM -#if USE_AOTRITON +#if USE_ROCM_ATTENTION auto stream = at::cuda::getCurrentCUDAStream().stream(); if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { auto dprops = at::cuda::getCurrentDeviceProperties(); diff --git a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h index 1709bf4d059..11f83ffef36 100644 --- a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h +++ b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h @@ -124,7 +124,7 @@ inline aotriton::TensorView<0> mk_aoscalartensor(const at::Tensor& q) inline aotriton::TensorView<0> mk_philoxtensor(const int64_t* ptr) { return aotriton::TensorView<0>(reinterpret_cast(ptr), - aotriton::DType::kUInt64); // AOTriton excepts unsigned int64 + aotriton::DType::kUInt64); // AOTriton accepts unsigned int64 } } // namespace aotriton_adapter diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip similarity index 96% rename from aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip rename to aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index dcbac79e317..598105ecef1 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -115,24 +115,18 @@ prepare_philox_arguments(float p_dropout, int64_t counter_offset) { #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") std::tuple -mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &out_, // batch_size x seqlen_q x num_heads x head_size - std::optional &alibi_slopes_, // num_heads or batch_size x num_heads - const float p_dropout, - const float softmax_scale, - bool is_causal, - int window_size_left, - int window_size_right, - const bool return_softmax, - std::optional gen_) { - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - // [ROCM specific]: must be at the beginning of the function - // Otherwise check_gpu_arch() checks cuda:0 device. 
- at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; - +mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_) { auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); check_gpu_arch(stream); @@ -242,7 +236,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head } std::tuple -mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i +mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i @@ -408,7 +402,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q } std::tuple -mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og +mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size @@ -559,7 +553,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si } std::tuple -mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size +mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i @@ -747,7 +741,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size return { dq, dk, dv, softmax_d }; } - -} // namespace pytorch_fmha +} // namespace pytorch_flash #endif diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/bias.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/bias.hpp new file mode 100644 index 00000000000..8115288fb88 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/bias.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
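Not part of the patch itself, but context for the pieces above: the new at::ROCmFABackend enum and the getROCmFAPreferredBackend()/setROCmFAPreferredBackend() pair in Context.{h,cpp} select between the AOTriton and CK flash-attention paths on ROCm at runtime, with TORCH_ROCM_FA_PREFER_CK=1 flipping the compiled-in default, while the CMakeLists.txt change gates generation of the CK kernel instances on the USE_CK_FLASH_ATTENTION build-time environment variable. A minimal C++ sketch of the runtime API, assuming a ROCm build (any Python-level binding is outside this section):

// Sketch only: request the CK flash-attention backend at runtime.
// setROCmFAPreferredBackend() additionally emits TORCH_WARN_ONCE when no
// gfx90a/gfx942 device is detected (see Context.cpp above).
#include <ATen/Context.h>

void prefer_ck_flash_attention() {
  at::globalContext().setROCmFAPreferredBackend(at::ROCmFABackend::Ck);
  if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) {
    // check_flash_attention_hardware_support() in sdp_utils.cpp above now
    // skips its AOTriton arch check when CK is the preferred backend.
  }
}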
+ +#pragma once + +#include +#include +#include +#include + +// keep sync with BlockAttentionBiasEnum +enum class bias_enum +{ + no_bias = 0, + elementwise_bias = 1, + alibi = 2, +}; + +struct bias_info +{ + bias_enum type; + /* + * simple dispatch logic + * + * if type == elementwise_bias: + * if rank_info == 0: + * bias is 1*1*s*s + * elif rank_info == 1: + * bias is 1*h*s*s + * elif rank_info == 2: + * bias is b*h*s*s + * + * elif type == alibi: + * if rank_info == 0: + * alibi in 1*h + * elif rank_info == 1: + * alibi in b*h + */ + int rank_info; + + void serialize(std::ostream& os) const + { + if(type == bias_enum::no_bias) + os << "n"; + else if(type == bias_enum::elementwise_bias) + { + os << "e"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + else if(type == bias_enum::alibi) + { + os << "alibi"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + } + + static bias_info decode(std::string str) + { + bias_info info{bias_enum::no_bias, 0}; + if(str == "0" || str == "n") + { + info.type = bias_enum::no_bias; + } + else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 || + str.compare(0, 11, "elementwise") == 0) + { + info.type = bias_enum::elementwise_bias; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 || + str.compare(0, 5, "alibi") == 0) + { + info.type = bias_enum::alibi; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + return info; + } + + friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) + { + bi.serialize(os); + return os; + } +}; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp new file mode 100644 index 00000000000..2f21bc13622 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp @@ -0,0 +1,447 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
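Not part of the patch: a short round trip of the bias_info helper defined in bias.hpp above. It mirrors CK's BlockAttentionBiasEnum and is keyed by short strings: "n" for no bias, "e"/"elementwise" with an optional ":<rank>" for an elementwise bias, and "a"/"alibi" with an optional ":<rank>" for ALiBi slopes.

#include <cassert>
#include <sstream>

void bias_info_roundtrip() {
  // "e:2" decodes to an elementwise bias laid out as b*h*s*s (rank_info == 2).
  bias_info bi = bias_info::decode("e:2");
  assert(bi.type == bias_enum::elementwise_bias && bi.rank_info == 2);

  std::ostringstream oss;
  oss << bi;                    // serialize() emits "e[2]"
  assert(oss.str() == "e[2]");

  // ALiBi slopes shaped b*h are requested with "alibi:1" (or "a:1").
  assert(bias_info::decode("alibi:1").type == bias_enum::alibi);
}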
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +template +struct FmhaBwdTypeConfig; + +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using GemmDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::half_t; + using OGradDataType = ck_tile::half_t; + using QGradDataType = ck_tile::half_t; + using KGradDataType = ck_tile::half_t; + using VGradDataType = ck_tile::half_t; + using BiasGradDataType = ck_tile::half_t; +}; + +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using GemmDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::bf16_t; + using OGradDataType = ck_tile::bf16_t; + using QGradDataType = ck_tile::bf16_t; + using KGradDataType = ck_tile::bf16_t; + using VGradDataType = ck_tile::bf16_t; + using BiasGradDataType = ck_tile::bf16_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +// runtime args, some will passed to karg, some will used to compute grids/blocks +struct fmha_bwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + const void* o_ptr; + const void* lse_ptr; + const void* do_ptr; + void* d_ptr; + void* rand_val_ptr; + void* dq_ptr; + void* dk_ptr; + void* dv_ptr; + void* dbias_ptr; + void* dq_acc_ptr; + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t max_seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + float scale; + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_o; + ck_tile::index_t stride_randval; + ck_tile::index_t stride_do; + ck_tile::index_t stride_dq_acc; + ck_tile::index_t stride_dq; + ck_tile::index_t stride_dk; + ck_tile::index_t stride_dv; + ck_tile::index_t stride_dbias; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_lsed; + ck_tile::index_t nhead_stride_dq_acc; + ck_tile::index_t nhead_stride_dq; + ck_tile::index_t nhead_stride_dk; + ck_tile::index_t nhead_stride_dv; + ck_tile::index_t nhead_stride_dbias; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t 
batch_stride_do; + ck_tile::index_t batch_stride_lsed; + ck_tile::index_t batch_stride_dq_acc; + ck_tile::index_t batch_stride_dq; + ck_tile::index_t batch_stride_dk; + ck_tile::index_t batch_stride_dv; + ck_tile::index_t batch_stride_dbias; + ck_tile::index_t split_stride_dq_acc; + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + float p_drop; + float p_undrop; + std::variant, std::pair> + drop_seed_offset; +}; + +template +auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) + { + return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.dq_acc_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dq_acc, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, + args.nhead_stride_dbias, + args.split_stride_dq_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.dq_acc_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dq_acc, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, + args.nhead_stride_dbias, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_do, + args.batch_stride_lsed, + args.batch_stride_dq_acc, + args.batch_stride_dk, + args.batch_stride_dv, + args.batch_stride_dbias, + args.split_stride_dq_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.drop_seed_offset); + } + }(); + + dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) +{ + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode) + { + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + args.d_ptr, + args.p_undrop, + args.seqstart_q_ptr, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, 
+ args.nhead_stride_o, + args.nhead_stride_lsed); + } + else + { // create batch mode kernel arguments + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + args.d_ptr, + args.p_undrop, + args.seqlen_q, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, + args.nhead_stride_o, + args.nhead_stride_lsed, + args.batch_stride_do, + args.batch_stride_o, + args.batch_stride_lsed); + } + }(); + + dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) +{ + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode) + { + return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, + args.dq_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.hdim_q, + args.stride_dq, + args.stride_dq_acc, + args.nhead_stride_dq, + args.nhead_stride_dq_acc, + args.split_stride_dq_acc); + } + else + { // create batch mode kernel arguments + return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, + args.dq_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.stride_dq, + args.stride_dq_acc, + args.nhead_stride_dq, + args.nhead_stride_dq_acc, + args.batch_stride_dq, + args.batch_stride_dq_acc, + args.split_stride_dq_acc); + } + }(); + + dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_bwd_dq_dk_dv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + using FmhaDropout = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kIsDeterministic = kIsDeterministic_; +}; + +template +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dq_dk_dv_get_name_(); + +template +struct fmha_bwd_dot_do_o_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dot_do_o_get_name_(); + +template +struct fmha_bwd_convert_dq_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kIsDeterministic = kIsDeterministic_; +}; + +template +float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void 
fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_convert_dq_get_name_(); + +// This is the public API, will be generated by script +struct fmha_bwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_dbias; + bool has_dropout; + bool is_store_randval; + bool is_deterministic; + // TODO: padding check is inside this api +}; +float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_00042c36bc588e60a7c8a9ba297a8a25d8ac0660.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_00042c36bc588e60a7c8a9ba297a8a25d8ac0660.hip new file mode 100644 index 00000000000..11ff05e5dbd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_00042c36bc588e60a7c8a9ba297a8a25d8ac0660.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0029076f83a3dc695a167beda6fe19230a2b114b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0029076f83a3dc695a167beda6fe19230a2b114b.hip new file mode 100644 index 00000000000..e50f11e48b2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0029076f83a3dc695a167beda6fe19230a2b114b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
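Not part of the patch: every fmha_ck_autogen_*.hip translation unit follows the pattern of the file above. It assembles one concrete FmhaBwdDQDKDVKernel from a block-tile/warp-tile configuration, explicitly specializes the three entry points declared in fmha_bwd.hpp for the matching dq_dk_dv_trait_ tag, and guards the device code to gfx90a/gfx942 so that every other target compiles to a stub that returns 0.0. Assuming one of those trait tags plus a populated fmha_bwd_args and a ck_tile::stream_config are in scope, a caller would use the entry points roughly as follows (a sketch, not code from this PR):

// Sketch only; dq_dk_dv_trait_0 is the tag defined by one generated file,
// `a` is a fully populated fmha_bwd_args and `scfg` a ck_tile::stream_config.
void run_one_bwd_instance(const ck_tile::stream_config& scfg, fmha_bwd_args a) {
  float t = fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(scfg, a);    // launch through ck_tile::launch_kernel,
                                                              // returning its timing result
  fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(scfg, a);      // single launch, no timing wrapper
  std::string name = fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>();  // kernel name, e.g. for logging
  (void)t; (void)name;
}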
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_006c417a52a1bd7c55e45d111483d26f4480caeb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_006c417a52a1bd7c55e45d111483d26f4480caeb.hip new file mode 100644 index 00000000000..f4235b476ee --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_006c417a52a1bd7c55e45d111483d26f4480caeb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_008f2429c678d13386a06e8d8b15c4b480940ff3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_008f2429c678d13386a06e8d8b15c4b480940ff3.hip new file mode 100644 index 00000000000..1fcb150bb0c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_008f2429c678d13386a06e8d8b15c4b480940ff3.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_00a2adbe938d458d51ca5fc4020667a215b672a4.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_00a2adbe938d458d51ca5fc4020667a215b672a4.hip new file mode 100644 index 00000000000..3da78aabd73 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_00a2adbe938d458d51ca5fc4020667a215b672a4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template 
<> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_012c0f480917c329f4c3c6c666cf32af2d82b294.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_012c0f480917c329f4c3c6c666cf32af2d82b294.hip new file mode 100644 index 00000000000..a5ee51bba50 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_012c0f480917c329f4c3c6c666cf32af2d82b294.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_014c209d5cfc6b965bfd78c64bf132c0154e32be.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_014c209d5cfc6b965bfd78c64bf132c0154e32be.hip new file mode 100644 index 00000000000..d23c280ce1c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_014c209d5cfc6b965bfd78c64bf132c0154e32be.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0153ec18d3ded0f8bdc6459ea5757ebd94d9faf2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0153ec18d3ded0f8bdc6459ea5757ebd94d9faf2.hip new file mode 100644 index 00000000000..b8272029ad3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0153ec18d3ded0f8bdc6459ea5757ebd94d9faf2.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ac1a2ecf9a487809e46faa92e267df2d47de91.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ac1a2ecf9a487809e46faa92e267df2d47de91.hip new file mode 100644 index 00000000000..c7cb12d5cd1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ac1a2ecf9a487809e46faa92e267df2d47de91.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ca79005067e20e4eed5a72ff9187cde702cd1c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ca79005067e20e4eed5a72ff9187cde702cd1c.hip new file mode 100644 index 00000000000..9cc45b050f3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ca79005067e20e4eed5a72ff9187cde702cd1c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
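Not part of the patch: fmha_bwd.hpp above splits the backward pass into three kernel families, and the data dependencies fix their order: fmha_bwd_dot_do_o_ first (per-row dot products of O and dO consumed by the main pass), fmha_bwd_dq_dk_dv_ next (the dK/dV/dQ-accumulator kernel shown in the generated files), and fmha_bwd_convert_dq_ last (casting the fp32 dQ accumulator to QGradDataType, as in the convert_dq instances above). The actual sequencing lives in the script-generated fmha_bwd() dispatcher, which is outside this section; a sketch with hypothetical trait tags:

// Sketch only: launch order implied by the declarations in fmha_bwd.hpp.
// dot_do_o_trait / dq_dk_dv_trait / convert_dq_trait stand in for concrete
// generated tags; `a` and `scfg` are as in the previous sketch.
void run_fmha_bwd_pipeline(const ck_tile::stream_config& scfg, fmha_bwd_args a) {
  fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait>(scfg, a);      // D = rowwise dot(O, dO)
  fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait>(scfg, a);      // main backward kernel: dK, dV, dQacc
  fmha_bwd_convert_dq_oneshot_<convert_dq_trait>(scfg, a);  // dQacc (fp32) -> dQ (QGradDataType)
}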
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01cb354dddef6e99e4ac843f2adafcddfc58d520.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01cb354dddef6e99e4ac843f2adafcddfc58d520.hip new file mode 100644 index 00000000000..304214921a4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01cb354dddef6e99e4ac843f2adafcddfc58d520.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01d12033d59ce2799a2a024e5d9232325ccf1320.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01d12033d59ce2799a2a024e5d9232325ccf1320.hip new file mode 100644 index 00000000000..12d5d208b29 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01d12033d59ce2799a2a024e5d9232325ccf1320.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01d3b034a2d8d0b83c0aefa4faac6c3f28ce737f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01d3b034a2d8d0b83c0aefa4faac6c3f28ce737f.hip new file mode 100644 index 00000000000..24c0a414244 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01d3b034a2d8d0b83c0aefa4faac6c3f28ce737f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e2428c5447aa9a78f79f73f31cf685c586872d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e2428c5447aa9a78f79f73f31cf685c586872d.hip new file mode 100644 index 00000000000..34117453c79 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e2428c5447aa9a78f79f73f31cf685c586872d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e8aedb7b7d77f44a46b2e9b7a826f245aaf4a7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e8aedb7b7d77f44a46b2e9b7a826f245aaf4a7.hip new file mode 100644 index 00000000000..7dff1690728 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e8aedb7b7d77f44a46b2e9b7a826f245aaf4a7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e8f0df0c54ce619e5b66441b3c96a5e18b05d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e8f0df0c54ce619e5b66441b3c96a5e18b05d6.hip new file mode 100644 index 00000000000..247ecb28f7a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e8f0df0c54ce619e5b66441b3c96a5e18b05d6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ee0083f6df962c4a754cd3295b1a436c590a0e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ee0083f6df962c4a754cd3295b1a436c590a0e.hip new file mode 100644 index 00000000000..c83349a6662 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ee0083f6df962c4a754cd3295b1a436c590a0e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01f74764c3c3284fdd1b67d0ea781c2261ed0de6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01f74764c3c3284fdd1b67d0ea781c2261ed0de6.hip new file mode 100644 index 00000000000..ecf4290b670 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01f74764c3c3284fdd1b67d0ea781c2261ed0de6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0225857454eaab2eb664aef7a0849ce12c32fdf9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0225857454eaab2eb664aef7a0849ce12c32fdf9.hip new file mode 100644 index 00000000000..2dcf470cb49 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0225857454eaab2eb664aef7a0849ce12c32fdf9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0237c76137df14fb808ade8bd6837045f2aaa5c9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0237c76137df14fb808ade8bd6837045f2aaa5c9.hip new file mode 100644 index 00000000000..780a28248b4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0237c76137df14fb808ade8bd6837045f2aaa5c9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0271bd8b7c270e1593871b638288a4923342c446.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0271bd8b7c270e1593871b638288a4923342c446.hip new file mode 100644 index 00000000000..7df54cbcb25 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0271bd8b7c270e1593871b638288a4923342c446.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_02d88a03cd3966dd0cff550065f58c3ffecfff6c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_02d88a03cd3966dd0cff550065f58c3ffecfff6c.hip new file mode 100644 index 00000000000..898d17870ed --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_02d88a03cd3966dd0cff550065f58c3ffecfff6c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
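Note on the shared structure of these generated files: each one fixes a single compile-time configuration (data type, block/warp tile sequence, padding, bias and dropout flags) in a trait alias such as dq_dk_dv_trait_0, instantiates the matching ck_tile kernel type, and then provides explicit specializations of fmha_bwd_dq_dk_dv_, fmha_bwd_dq_dk_dv_oneshot_, and fmha_bwd_dq_dk_dv_get_name_ (or the convert_dq / forward equivalents), with the launch body fenced by an arch check so that builds targeting anything other than gfx90a/gfx942 compile down to a stub that returns 0.0. The float return appears to be the timing value reported by ck_tile::launch_kernel. The following is a minimal, self-contained C++ sketch of that trait-keyed dispatch pattern; the names bwd_trait, run_bwd, run_bwd_name, launch_args, example_trait, and EXAMPLE_ARCH_SUPPORTED are hypothetical stand-ins for illustration, not actual ck_tile APIs.

    // Illustrative only: simplified stand-ins for the pattern used by the
    // autogenerated .hip files above. None of these names are real ck_tile APIs.
    #include <iostream>
    #include <string>

    // A trait tag encodes the compile-time configuration one generated file
    // was produced for (head dim, data type, causal masking, ...).
    template <int HeadDim, typename DataType, bool kIsCausal>
    struct bwd_trait {};

    struct launch_args {
        int batch;
        int seqlen;
    };

    // Primary templates: declared once in a shared header, never defined.
    // Each generated translation unit supplies exactly one explicit
    // specialization, so only the emitted variants exist at link time.
    template <typename Trait>
    float run_bwd(const launch_args& a);

    template <typename Trait>
    std::string run_bwd_name();

    // What one generated file contributes, with the launch body fenced by an
    // arch check so unsupported targets compile to a no-op stub.
    using example_trait = bwd_trait<64, float, true>;

    template <>
    float run_bwd<example_trait>(const launch_args& a)
    {
        (void)a; // a real specialization would build kernel args and grids from `a`
    #if defined(EXAMPLE_ARCH_SUPPORTED)
        // ...and launch the kernel here, returning its measured timing
        return 1.0f; // placeholder for the elapsed milliseconds
    #else
        return 0.0f; // unsupported target: the whole body collapses to a stub
    #endif
    }

    template <>
    std::string run_bwd_name<example_trait>()
    {
        return "bwd_hd64_f32_causal";
    }

    int main()
    {
        launch_args a{1, 128};
        std::cout << run_bwd_name<example_trait>() << ": "
                  << run_bwd<example_trait>(a) << " ms\n";
        return 0;
    }

In the sketch the primary template is only declared, so a missing specialization surfaces as a link error rather than a silent fallback to a generic path; keeping one specialization per translation unit, as the generated .hip files do, also keeps each individual compile job small.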
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_02ff94e3c787a7b06ffc90c25777fa74f225e32c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_02ff94e3c787a7b06ffc90c25777fa74f225e32c.hip new file mode 100644 index 00000000000..7e08824282c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_02ff94e3c787a7b06ffc90c25777fa74f225e32c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_030a759dcc92028b4c6f317fc230b98cb929e806.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_030a759dcc92028b4c6f317fc230b98cb929e806.hip new file mode 100644 index 00000000000..cc44db88758 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_030a759dcc92028b4c6f317fc230b98cb929e806.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_031b12f9fd94e01aaff2c0da4f35f346822087e4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_031b12f9fd94e01aaff2c0da4f35f346822087e4.hip new file mode 100644 index 00000000000..2f4c053bc74 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_031b12f9fd94e01aaff2c0da4f35f346822087e4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_036887daf6cc092e7422a17882488e59cecfb643.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_036887daf6cc092e7422a17882488e59cecfb643.hip new file mode 100644 index 00000000000..cd13097b818 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_036887daf6cc092e7422a17882488e59cecfb643.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_037c6c80fcec3eb8b0bef50ad6af6d27bf5447f5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_037c6c80fcec3eb8b0bef50ad6af6d27bf5447f5.hip new file mode 100644 index 00000000000..ad0c065328a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_037c6c80fcec3eb8b0bef50ad6af6d27bf5447f5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0392491c5a6dfc742c2be483419a40f6a7a7ea56.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0392491c5a6dfc742c2be483419a40f6a7a7ea56.hip new file mode 100644 index 00000000000..913e1819e46 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0392491c5a6dfc742c2be483419a40f6a7a7ea56.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
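Each of the autogenerated CK files above and below follows the same structure: it pins one backward FMHA variant (data type, block/warp tile shape, mask, bias, dropout, and padding flags) in a dq_dk_dv_trait_0 alias, then provides explicit specializations of the launch, one-shot, and get-name helpers for that trait, compiled only for gfx90a/gfx942. The sketch below is illustrative only; it reuses the helper names visible in the generated code (fmha_bwd_dq_dk_dv_, fmha_bwd_args) but replaces the ck_tile types with minimal stand-ins, so it shows the per-trait specialization pattern rather than the real kernel launch.

// Minimal, self-contained sketch of the per-trait specialization pattern used
// by the generated files; stand-in types, not the real ck_tile definitions.
#include <iostream>
#include <string>

namespace sketch {

struct stream_config { int log_level_ = 0; }; // stand-in for ck_tile::stream_config
struct fmha_bwd_args {};                      // stand-in for the real argument struct

// Primary templates, declared once in a shared header; each generated file
// contributes the specializations for exactly one trait type.
template <typename Trait>
float fmha_bwd_dq_dk_dv_(const stream_config& s, fmha_bwd_args a);

template <typename Trait>
std::string fmha_bwd_dq_dk_dv_get_name_();

// One trait per kernel variant (dtype, head dim, mask, bias, dropout, ...).
struct dq_dk_dv_trait_example {};

template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_example>()
{
    return "fmha_bwd_dq_dk_dv_example";
}

template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_example>(const stream_config& s, fmha_bwd_args /*a*/)
{
    if (s.log_level_ > 0)
        std::cout << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_example>() << std::flush;
#if defined(__gfx90a__) || defined(__gfx942__)
    // The real files build kernel arguments and grids here and launch through
    // ck_tile::launch_kernel, returning the measured time in milliseconds.
    return 1.0f;
#else
    // Other architectures fall through to a no-op, exactly as in the generated code.
    return 0.0f;
#endif
}

} // namespace sketch

int main()
{
    sketch::stream_config s{1};
    std::cout << sketch::fmha_bwd_dq_dk_dv_<sketch::dq_dk_dv_trait_example>(s, {}) << "\n";
    return 0;
}

Runtime dispatch then reduces to selecting the trait whose compile-time flags match the request, which is why this patch adds so many near-identical translation units: each one isolates the heavy template instantiation for a single kernel variant in its own compilation unit.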
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_03a71615a088e972c998f9c7cb44566c268c5124.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_03a71615a088e972c998f9c7cb44566c268c5124.hip new file mode 100644 index 00000000000..d48e083be95 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_03a71615a088e972c998f9c7cb44566c268c5124.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_03ff035717140f7385282419598cb4fb2881ce8e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_03ff035717140f7385282419598cb4fb2881ce8e.hip new file mode 100644 index 00000000000..12bfd890c6f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_03ff035717140f7385282419598cb4fb2881ce8e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_041a0718891596ddac1fb0088637029233ccbe60.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_041a0718891596ddac1fb0088637029233ccbe60.hip new file mode 100644 index 00000000000..ed3cc93a059 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_041a0718891596ddac1fb0088637029233ccbe60.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_042a156e9eb935555ab14a84461959b466c2fb5b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_042a156e9eb935555ab14a84461959b466c2fb5b.hip new file mode 100644 index 00000000000..8a0da15aa5e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_042a156e9eb935555ab14a84461959b466c2fb5b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04641230fe9a50a221047f7a1df8a370f72805b9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04641230fe9a50a221047f7a1df8a370f72805b9.hip new file mode 100644 index 00000000000..4ed1496285d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04641230fe9a50a221047f7a1df8a370f72805b9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = 
fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04c363e11d202c6d2f4bb753661c5a2043edc0ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04c363e11d202c6d2f4bb753661c5a2043edc0ad.hip new file mode 100644 index 00000000000..74be5c1f195 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04c363e11d202c6d2f4bb753661c5a2043edc0ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04caeecbc01667ec6f5599358a0a20423aa9a00b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04caeecbc01667ec6f5599358a0a20423aa9a00b.hip new file mode 100644 index 00000000000..81b58c933d5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04caeecbc01667ec6f5599358a0a20423aa9a00b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04f39b453505f68a5091f68b1c3de48369d1e7ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04f39b453505f68a5091f68b1c3de48369d1e7ea.hip new file mode 100644 index 00000000000..513ed9b0ba6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04f39b453505f68a5091f68b1c3de48369d1e7ea.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04ffca078cfab8bc6c4ccd1cc8994a1bb4a88ea7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04ffca078cfab8bc6c4ccd1cc8994a1bb4a88ea7.hip new file mode 100644 index 00000000000..7498169d408 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04ffca078cfab8bc6c4ccd1cc8994a1bb4a88ea7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0502e718337eab7d47aa65cea7d3c5f641484520.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0502e718337eab7d47aa65cea7d3c5f641484520.hip new file mode 100644 index 00000000000..4248c1f8bc4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0502e718337eab7d47aa65cea7d3c5f641484520.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0513b2f3bd8ad51315aadb7f63737201898adca8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0513b2f3bd8ad51315aadb7f63737201898adca8.hip new file mode 100644 index 00000000000..d2f0ff760be --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0513b2f3bd8ad51315aadb7f63737201898adca8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_053981d9e7af2ebc0f91e61ac5e25cbe68c95bd8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_053981d9e7af2ebc0f91e61ac5e25cbe68c95bd8.hip new file mode 100644 index 00000000000..e9188fb65ab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_053981d9e7af2ebc0f91e61ac5e25cbe68c95bd8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_054fda16133a0d25077967b05425f9128e1fe1a5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_054fda16133a0d25077967b05425f9128e1fe1a5.hip new file mode 100644 index 00000000000..515bcf6b65e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_054fda16133a0d25077967b05425f9128e1fe1a5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05538339c21c92c53d237865d72debaaf2ee5075.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05538339c21c92c53d237865d72debaaf2ee5075.hip new file mode 100644 index 00000000000..dd0fb7aea27 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05538339c21c92c53d237865d72debaaf2ee5075.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0595316f0dfffda03e5296b959a49ec3f3c48d67.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0595316f0dfffda03e5296b959a49ec3f3c48d67.hip new file mode 100644 index 00000000000..8a8d141623d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0595316f0dfffda03e5296b959a49ec3f3c48d67.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05dfe927fd64a564c5fad537fb7c41ee9c94c2c0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05dfe927fd64a564c5fad537fb7c41ee9c94c2c0.hip new file mode 100644 index 00000000000..af08208640a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05dfe927fd64a564c5fad537fb7c41ee9c94c2c0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05e60b3ab7477f9edc8576a8bf43e3a62b8d5ef8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05e60b3ab7477f9edc8576a8bf43e3a62b8d5ef8.hip new file mode 100644 index 00000000000..6173f431eac --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05e60b3ab7477f9edc8576a8bf43e3a62b8d5ef8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05f794c7023cbb7e35f1fd1ae45bd2377bfbc520.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05f794c7023cbb7e35f1fd1ae45bd2377bfbc520.hip new file mode 100644 index 00000000000..a461a8ad3c6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05f794c7023cbb7e35f1fd1ae45bd2377bfbc520.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0628931bf5cc1daa6e106cf60bb21fa1aac6b1df.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0628931bf5cc1daa6e106cf60bb21fa1aac6b1df.hip new file mode 100644 index 00000000000..36a2eba38f7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0628931bf5cc1daa6e106cf60bb21fa1aac6b1df.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_062c8c3c1cf6c33af4574099e9b6ac54a55ad776.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_062c8c3c1cf6c33af4574099e9b6ac54a55ad776.hip new file mode 100644 index 00000000000..9fd67589da9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_062c8c3c1cf6c33af4574099e9b6ac54a55ad776.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0682150e93f547e00f13cd8984779bf49b91e50c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0682150e93f547e00f13cd8984779bf49b91e50c.hip new file mode 100644 index 00000000000..2200cc8523f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0682150e93f547e00f13cd8984779bf49b91e50c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_069c663be0267c009be4814e9e4e7c13ec999411.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_069c663be0267c009be4814e9e4e7c13ec999411.hip new file mode 100644 index 00000000000..e443f8ec729 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_069c663be0267c009be4814e9e4e7c13ec999411.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06ae52ef937cc27c544e32025ea0dadb7fad982d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06ae52ef937cc27c544e32025ea0dadb7fad982d.hip new file mode 100644 index 00000000000..6100c507d6c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06ae52ef937cc27c544e32025ea0dadb7fad982d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06b74acd9abfbd1c4ec2f4c718eeb92a0bca7bab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06b74acd9abfbd1c4ec2f4c718eeb92a0bca7bab.hip new file mode 100644 index 00000000000..e59ac40198d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06b74acd9abfbd1c4ec2f4c718eeb92a0bca7bab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
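Note (editor): every autogenerated .hip file in this patch follows the same template — it hard-codes one tile/pipeline/bias/dropout configuration as a trait alias, explicitly specializes fmha_bwd_dq_dk_dv_ (plus the oneshot_ and get_name_ variants) for that trait, and guards the actual launch behind __gfx90a__/__gfx942__ so other targets compile to a stub that returns 0.0. The sketch below is a simplified, self-contained illustration of that explicit-specialization-plus-runtime-dispatch pattern in plain C++; the Trait, Args, fmha_bwd_run_ and dispatch_fmha_bwd names are illustrative stand-ins, not the actual ck_tile or PyTorch API.

// Simplified sketch of the pattern used by the autogenerated kernel files:
// one translation unit per trait, one explicit specialization per trait, and a
// runtime dispatcher that picks the specialization matching the call site.
// All names here (Trait, fmha_bwd_run_, dispatch_fmha_bwd) are hypothetical
// stand-ins for the real ck_tile / fmha_bwd types, not the actual API.
#include <iostream>
#include <stdexcept>

// A trait bundles the compile-time choices each generated file hard-codes
// (head dim, data type, causal masking, ...).
template <int HeadDim, typename DataType, bool kIsCausal>
struct Trait {};

struct Bf16 {};  // stand-ins for ck_tile::bf16_t / ck_tile::fp16_t
struct Fp16 {};

struct Args {    // stand-in for fmha_bwd_args
  int head_dim;
  bool is_bf16;
  bool is_causal;
};

// The primary template is only declared; every generated file provides exactly
// one explicit specialization, so unused configurations are never compiled in.
template <typename T>
float fmha_bwd_run_(const Args& a);

template <>
float fmha_bwd_run_<Trait<64, Bf16, false>>(const Args&) {
  // In the real files this builds kargs/grids and launches the CK kernel,
  // guarded by #if defined(__gfx90a__) || defined(__gfx942__).
  std::cout << "launch hdim64 bf16 non-causal backward kernel\n";
  return 0.0f;
}

template <>
float fmha_bwd_run_<Trait<128, Fp16, true>>(const Args&) {
  std::cout << "launch hdim128 fp16 causal backward kernel\n";
  return 0.0f;
}

// Runtime dispatch: map the runtime arguments onto one of the explicitly
// instantiated traits, mirroring what the fmha_bwd entry point does.
float dispatch_fmha_bwd(const Args& a) {
  if (a.head_dim == 64 && a.is_bf16 && !a.is_causal)
    return fmha_bwd_run_<Trait<64, Bf16, false>>(a);
  if (a.head_dim == 128 && !a.is_bf16 && a.is_causal)
    return fmha_bwd_run_<Trait<128, Fp16, true>>(a);
  throw std::runtime_error("no kernel instantiated for this configuration");
}

int main() {
  dispatch_fmha_bwd({64, /*is_bf16=*/true, /*is_causal=*/false});
  dispatch_fmha_bwd({128, /*is_bf16=*/false, /*is_causal=*/true});
  return 0;
}

Splitting one configuration per file keeps each translation unit small enough to compile in parallel, which is why the CMake change above globs the whole flash_attn/ck directory instead of listing files explicitly.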
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06ba94794a14f0f0022af6f5f3c16e1e16959d4c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06ba94794a14f0f0022af6f5f3c16e1e16959d4c.hip new file mode 100644 index 00000000000..9487fc53035 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06ba94794a14f0f0022af6f5f3c16e1e16959d4c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_071751b1012b90f7b57f8591cd06ae1fd27d9cd3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_071751b1012b90f7b57f8591cd06ae1fd27d9cd3.hip new file mode 100644 index 00000000000..d437c6cd60e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_071751b1012b90f7b57f8591cd06ae1fd27d9cd3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0766e7aa4b263a811408b285213e47176ee2bdaf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0766e7aa4b263a811408b285213e47176ee2bdaf.hip new file mode 100644 index 00000000000..f952c1535c9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0766e7aa4b263a811408b285213e47176ee2bdaf.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_076b3beb57b30afb30636f948e3989b346b38d20.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_076b3beb57b30afb30636f948e3989b346b38d20.hip new file mode 100644 index 00000000000..0babc637454 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_076b3beb57b30afb30636f948e3989b346b38d20.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0789852b0cd3cc030c78b28f2fd5b6b0546382a4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0789852b0cd3cc030c78b28f2fd5b6b0546382a4.hip new file mode 100644 index 00000000000..866ebbec981 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0789852b0cd3cc030c78b28f2fd5b6b0546382a4.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_078b96ad691a85eebd18586db0b62b8911016d9c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_078b96ad691a85eebd18586db0b62b8911016d9c.hip new file mode 100644 index 00000000000..64c878eccfe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_078b96ad691a85eebd18586db0b62b8911016d9c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS 
AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_07c3fc96d2bebe546dce6ebf46e5c7a519959599.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_07c3fc96d2bebe546dce6ebf46e5c7a519959599.hip new file mode 100644 index 00000000000..883582e9c67 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_07c3fc96d2bebe546dce6ebf46e5c7a519959599.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_07ff04fcc273e469737512893ea3fb5876ac131d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_07ff04fcc273e469737512893ea3fb5876ac131d.hip new file mode 100644 index 00000000000..2819cd97411 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_07ff04fcc273e469737512893ea3fb5876ac131d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS 
CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0801c56831b4c6428200db6318638a2129bb197a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0801c56831b4c6428200db6318638a2129bb197a.hip new file mode 100644 index 00000000000..1d293ba529b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0801c56831b4c6428200db6318638a2129bb197a.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0836d5dfc0f939ab9a4064b403339373caf35b56.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0836d5dfc0f939ab9a4064b403339373caf35b56.hip new file mode 100644 index 00000000000..66ca2006e4a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0836d5dfc0f939ab9a4064b403339373caf35b56.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0842c4e3aabdf55405b3ce09ce1899245ddf11ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0842c4e3aabdf55405b3ce09ce1899245ddf11ad.hip new file mode 100644 index 00000000000..2a53b540abe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0842c4e3aabdf55405b3ce09ce1899245ddf11ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_085722b43cde5f37242edb071f639da7c4a0bd48.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_085722b43cde5f37242edb071f639da7c4a0bd48.hip new file mode 100644 index 00000000000..f42aa3ef490 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_085722b43cde5f37242edb071f639da7c4a0bd48.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0878b9aa31429d23a93cd953cc6a2fc5f43d0d3a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0878b9aa31429d23a93cd953cc6a2fc5f43d0d3a.hip new file mode 100644 index 00000000000..256a6393e9d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0878b9aa31429d23a93cd953cc6a2fc5f43d0d3a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_089a347aef8a920e3b59d5ffe71fc5bfe002609c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_089a347aef8a920e3b59d5ffe71fc5bfe002609c.hip new file mode 100644 index 00000000000..9463f524b13 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_089a347aef8a920e3b59d5ffe71fc5bfe002609c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_089de13222caec1483207d4a54249f8da4f9c151.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_089de13222caec1483207d4a54249f8da4f9c151.hip new file mode 100644 index 00000000000..e21b2f479e5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_089de13222caec1483207d4a54249f8da4f9c151.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_091cb49c1958fb4342d79f367ea93cf2b472f785.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_091cb49c1958fb4342d79f367ea93cf2b472f785.hip new file mode 100644 index 00000000000..e6a493bdbd0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_091cb49c1958fb4342d79f367ea93cf2b472f785.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_093834d4d3fe76e1745e4482c6b51b550c6f3dfc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_093834d4d3fe76e1745e4482c6b51b550c6f3dfc.hip new file mode 100644 index 00000000000..0f9deffe70e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_093834d4d3fe76e1745e4482c6b51b550c6f3dfc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09513bff5c1da6aadf11d2e8272a422eabff21bc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09513bff5c1da6aadf11d2e8272a422eabff21bc.hip new file mode 100644 index 00000000000..656d540789c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09513bff5c1da6aadf11d2e8272a422eabff21bc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_096863cd93d1b105a617d0daa1d4f37d7fb6b893.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_096863cd93d1b105a617d0daa1d4f37d7fb6b893.hip new file mode 100644 index 00000000000..69bc9277c75 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_096863cd93d1b105a617d0daa1d4f37d7fb6b893.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0968cebd81ade762c2f92fffc0153fa7a2b91eb5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0968cebd81ade762c2f92fffc0153fa7a2b91eb5.hip new file mode 100644 index 00000000000..474717b1a20 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0968cebd81ade762c2f92fffc0153fa7a2b91eb5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_096e888c52d0f4a5847d7515fcc66208b1ff40d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_096e888c52d0f4a5847d7515fcc66208b1ff40d3.hip new file mode 100644 index 00000000000..74f365d6186 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_096e888c52d0f4a5847d7515fcc66208b1ff40d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_097b3e1dae9bfb2e89398706508f8e01966fd4ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_097b3e1dae9bfb2e89398706508f8e01966fd4ea.hip new file mode 100644 index 00000000000..fa7da3a3854 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_097b3e1dae9bfb2e89398706508f8e01966fd4ea.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09d76cca48b71dbcc9bd96734787209fee4c9a74.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09d76cca48b71dbcc9bd96734787209fee4c9a74.hip new file mode 100644 index 00000000000..d51aeab7bf1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09d76cca48b71dbcc9bd96734787209fee4c9a74.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09e50367b62bb09071e28b44235a7c112645a706.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09e50367b62bb09071e28b44235a7c112645a706.hip new file mode 100644 index 00000000000..94cc58dc314 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09e50367b62bb09071e28b44235a7c112645a706.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09ecb6347009f6a5d5530a6acf90f9f40288cbcf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09ecb6347009f6a5d5530a6acf90f9f40288cbcf.hip new file mode 100644 index 00000000000..544b7c9a8b4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09ecb6347009f6a5d5530a6acf90f9f40288cbcf.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a2b116fd5065109aae46ee547e4f49ad0e9d6e1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a2b116fd5065109aae46ee547e4f49ad0e9d6e1.hip new file mode 100644 index 00000000000..abd89678478 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a2b116fd5065109aae46ee547e4f49ad0e9d6e1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a4e76d89b175e1d9fd2e9fb908d5fce1ebb945d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a4e76d89b175e1d9fd2e9fb908d5fce1ebb945d.hip new file mode 100644 index 00000000000..e6671df3102 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a4e76d89b175e1d9fd2e9fb908d5fce1ebb945d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a55ed15ef58c941e06dda890aeb530e28eb7bba.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a55ed15ef58c941e06dda890aeb530e28eb7bba.hip new file mode 100644 index 00000000000..29800976b6f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a55ed15ef58c941e06dda890aeb530e28eb7bba.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a672fca51de618e3441cf8764e8e83eb782f2c7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a672fca51de618e3441cf8764e8e83eb782f2c7.hip new file mode 100644 index 00000000000..e48d26d4247 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a672fca51de618e3441cf8764e8e83eb782f2c7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a68c2f9a3acdd787b81be455cbc7836c8bfd90c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a68c2f9a3acdd787b81be455cbc7836c8bfd90c.hip new file mode 100644 index 00000000000..27872f6fae4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a68c2f9a3acdd787b81be455cbc7836c8bfd90c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a89417a043556970f72eebd48b4f3e7ac15377a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a89417a043556970f72eebd48b4f3e7ac15377a.hip new file mode 100644 index 00000000000..eee7d6d67f1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a89417a043556970f72eebd48b4f3e7ac15377a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) 
|| defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a92671b6ea99891c0d69b1c793f4d131b9a82ed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a92671b6ea99891c0d69b1c793f4d131b9a82ed.hip new file mode 100644 index 00000000000..da711799b09 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a92671b6ea99891c0d69b1c793f4d131b9a82ed.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = 
+ ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0aafb881e34a3794970a1282af740b3f19c138b1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0aafb881e34a3794970a1282af740b3f19c138b1.hip new file mode 100644 index 00000000000..673bab07358 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0aafb881e34a3794970a1282af740b3f19c138b1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ace6e29e1d3060c3086c08fe27b471e375f9c75.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ace6e29e1d3060c3086c08fe27b471e375f9c75.hip new file mode 100644 index 00000000000..322083d0ac3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ace6e29e1d3060c3086c08fe27b471e375f9c75.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ad9d68fcee021437e13ffdf94d78252205f5a31.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ad9d68fcee021437e13ffdf94d78252205f5a31.hip new file mode 100644 index 00000000000..647879bfac5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ad9d68fcee021437e13ffdf94d78252205f5a31.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b2647b5982405a48e8c8888552a4b89386ccdd9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b2647b5982405a48e8c8888552a4b89386ccdd9.hip new file mode 100644 index 00000000000..f63d42b45ce --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b2647b5982405a48e8c8888552a4b89386ccdd9.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b2efefea81036641561bed80c75d77651176f74.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b2efefea81036641561bed80c75d77651176f74.hip new file mode 100644 index 00000000000..ff32eecf1e7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b2efefea81036641561bed80c75d77651176f74.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b3153af7bcdba33115a0d31f121fd76be2ffbcc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b3153af7bcdba33115a0d31f121fd76be2ffbcc.hip new file mode 100644 index 00000000000..85408d275fe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b3153af7bcdba33115a0d31f121fd76be2ffbcc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b532fcf26f90c82a792cde7943634f667c1d033.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b532fcf26f90c82a792cde7943634f667c1d033.hip new file mode 100644 index 00000000000..21b0570baad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b532fcf26f90c82a792cde7943634f667c1d033.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b90a0186d8b8004e3f19886c7992c8e04d0e066.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b90a0186d8b8004e3f19886c7992c8e04d0e066.hip new file mode 100644 index 00000000000..2b98ae63673 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b90a0186d8b8004e3f19886c7992c8e04d0e066.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b9585ba1c10acf67115c5899b3546608541820d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b9585ba1c10acf67115c5899b3546608541820d.hip new file mode 100644 index 00000000000..7520a8551e4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b9585ba1c10acf67115c5899b3546608541820d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0bb81407c8a2b3cdc5fecf655b3ad64d5d729cc9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0bb81407c8a2b3cdc5fecf655b3ad64d5d729cc9.hip new file mode 100644 index 00000000000..324bb25c120 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0bb81407c8a2b3cdc5fecf655b3ad64d5d729cc9.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0bc7910aac798f0555e9e505ad7f177c9fbbd92c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0bc7910aac798f0555e9e505ad7f177c9fbbd92c.hip new file mode 100644 index 00000000000..fde747894ac --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0bc7910aac798f0555e9e505ad7f177c9fbbd92c.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0be8cf70c6be969ecfca675782c860b5b75ac089.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0be8cf70c6be969ecfca675782c860b5b75ac089.hip new file mode 100644 index 00000000000..e3bcce7ab85 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0be8cf70c6be969ecfca675782c860b5b75ac089.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
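Each of the generated `fmha_ck_autogen_*.hip` files above follows the same pattern: it pins one compile-time configuration (data type, block/warp tile shape, mask, dropout, bias and padding flags) in a `*_traits_` alias, then provides explicit specializations of the templated launchers declared in the shared header (`fmha_bwd_dq_dk_dv_<trait>`, `fmha_bwd_dq_dk_dv_oneshot_<trait>`, `fmha_bwd_dq_dk_dv_get_name_<trait>`, and the forward-path `fmha_fwd_<trait>`), guarded to the gfx90a/gfx942 targets. The sketch below is a minimal, self-contained illustration of this one-specialization-per-translation-unit pattern only; `KernelTraits`, `launch_`, and `dispatch` are hypothetical stand-ins and are not part of ck_tile, the generated headers, or PyTorch's actual dispatch code.

```cpp
// Minimal sketch (hypothetical names) of the "one explicit specialization per
// translation unit" pattern used by the generated fmha_ck_autogen_*.hip files.
#include <iostream>

// A trait type encodes one compile-time configuration (dtype, tile, bias, ...).
template <int HeadDim, bool IsCausal>
struct KernelTraits {};

// The templated launcher is declared once in a shared header; every generated
// file supplies the definition for exactly one trait instantiation
// (analogous to fmha_bwd_dq_dk_dv_<trait> / fmha_fwd_<trait> above).
template <typename Traits>
float launch_(float scale);  // declaration only

// --- what one "generated" translation unit would contain -------------------
using trait_hd128_causal = KernelTraits<128, true>;

template <>
float launch_<trait_hd128_causal>(float scale)
{
    // A real instantiation builds kernel arguments and launches the GPU
    // kernel; this stub just reports which variant was selected.
    std::cout << "hd128/causal kernel, scale=" << scale << "\n";
    return 0.0f;
}
// ----------------------------------------------------------------------------

// A hypothetical dispatcher maps runtime parameters onto one instantiation.
float dispatch(int head_dim, bool causal, float scale)
{
    if (head_dim == 128 && causal)
        return launch_<trait_hd128_causal>(scale);
    return -1.0f;  // no matching specialization compiled in
}

int main()
{
    return dispatch(128, /*causal=*/true, 0.125f) == 0.0f ? 0 : 1;
}
```

Splitting the instantiations across many small translation units in this way presumably keeps per-file compile time and register pressure manageable and lets the build compile the variants in parallel, which is consistent with the CMake warning earlier in this change about build time growing when CK is built for multiple gfx architectures.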
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0befed50a89d80c22b2c8c3d5ba67d73c3d0190e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0befed50a89d80c22b2c8c3d5ba67d73c3d0190e.hip new file mode 100644 index 00000000000..0a92513862c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0befed50a89d80c22b2c8c3d5ba67d73c3d0190e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c32a2d9701e23dd930119c4ee8089042b5b0ac5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c32a2d9701e23dd930119c4ee8089042b5b0ac5.hip new file mode 100644 index 00000000000..19dc9a2ff49 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c32a2d9701e23dd930119c4ee8089042b5b0ac5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c3b2ec99fa7b09c7f78dcc3142a661d686044ac.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c3b2ec99fa7b09c7f78dcc3142a661d686044ac.hip new file mode 100644 index 00000000000..e7b6c90da08 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c3b2ec99fa7b09c7f78dcc3142a661d686044ac.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c8a0bb89a6f05289c0405df5126fa0cc16252e7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c8a0bb89a6f05289c0405df5126fa0cc16252e7.hip new file mode 100644 index 00000000000..ea61a8c916e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c8a0bb89a6f05289c0405df5126fa0cc16252e7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c93c65e5942a2f43f2e491547add02777dd2eee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c93c65e5942a2f43f2e491547add02777dd2eee.hip new file mode 100644 index 00000000000..2d8d09cf2eb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c93c65e5942a2f43f2e491547add02777dd2eee.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c9bd38b8f9009d932ec49204fdea39a52885246.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c9bd38b8f9009d932ec49204fdea39a52885246.hip new file mode 100644 index 00000000000..8473617ba95 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c9bd38b8f9009d932ec49204fdea39a52885246.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
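Each of the generated units above and below provides explicit specializations of the backward dispatch entry points fmha_bwd_dq_dk_dv_, fmha_bwd_dq_dk_dv_oneshot_, and fmha_bwd_dq_dk_dv_get_name_ for the trait alias defined just before the trailing #include. The following is a minimal sketch of how host code could call such a specialization; it assumes the generated dispatch header is already in scope and simply reuses the dq_dk_dv_trait_0 alias, fmha_bwd_args, and ck_tile::stream_config names from the code above. It is illustrative only, not the actual PyTorch/CK dispatcher.

// Illustrative sketch only. Assumes the generated fmha_bwd dispatch header
// (which declares fmha_bwd_args and the templated entry points) is included,
// along with one of the generated instance files such as those in this diff.
#include <iostream>

float run_bwd_instance(const ck_tile::stream_config& cfg, fmha_bwd_args args)
{
    // Log which concrete kernel instance this trait alias maps to.
    std::cout << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>() << std::endl;

    // Launch the dq/dk/dv kernel for this trait. The generated body returns
    // ck_tile::launch_kernel's float result, or 0.0 when the instance is
    // compiled out for the current gfx arch (see the __gfx90a__/__gfx942__
    // guard in the definitions above).
    return fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(cfg, args);
}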
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0caeedaa7d50f1741d618fb6c573529eebb075b1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0caeedaa7d50f1741d618fb6c573529eebb075b1.hip new file mode 100644 index 00000000000..78b0ce5c6f2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0caeedaa7d50f1741d618fb6c573529eebb075b1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0cdef49859c80c6b3ba18eb2fb4c35c72abc1cf2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0cdef49859c80c6b3ba18eb2fb4c35c72abc1cf2.hip new file mode 100644 index 00000000000..1ce29ae2f99 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0cdef49859c80c6b3ba18eb2fb4c35c72abc1cf2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0cee6b9427c164d78994150305a47f73954a67c0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0cee6b9427c164d78994150305a47f73954a67c0.hip new file mode 100644 index 00000000000..c54a987cac5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0cee6b9427c164d78994150305a47f73954a67c0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0d0e0147a92061d32608a34e7b47bd534eb787fa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0d0e0147a92061d32608a34e7b47bd534eb787fa.hip new file mode 100644 index 00000000000..78f56b54f70 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0d0e0147a92061d32608a34e7b47bd534eb787fa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0d13a4c8d169877da6408584dc1f20a6f7c5e3aa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0d13a4c8d169877da6408584dc1f20a6f7c5e3aa.hip new file mode 100644 index 00000000000..78dfc5a9b1a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0d13a4c8d169877da6408584dc1f20a6f7c5e3aa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0dde401aa76cb5425563cbbdb0362748148da3ca.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0dde401aa76cb5425563cbbdb0362748148da3ca.hip new file mode 100644 index 00000000000..db5239ecd29 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0dde401aa76cb5425563cbbdb0362748148da3ca.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e007c36231ccdae12f102eacca1f74b0711b9c6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e007c36231ccdae12f102eacca1f74b0711b9c6.hip new file mode 100644 index 00000000000..79a2821f3f7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e007c36231ccdae12f102eacca1f74b0711b9c6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e0a2370f2a320484d8f9f21e3197425c2dbe9ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e0a2370f2a320484d8f9f21e3197425c2dbe9ad.hip new file mode 100644 index 00000000000..11364420bf6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e0a2370f2a320484d8f9f21e3197425c2dbe9ad.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e1dbc9c433ce8ec33ace9e62550261d613db582.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e1dbc9c433ce8ec33ace9e62550261d613db582.hip new file mode 100644 index 00000000000..db47fbad4d8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e1dbc9c433ce8ec33ace9e62550261d613db582.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e3f4cd28a4c06cc109f6a0798a77844bcc750b7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e3f4cd28a4c06cc109f6a0798a77844bcc750b7.hip new file mode 100644 index 00000000000..c086f7b6e86 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e3f4cd28a4c06cc109f6a0798a77844bcc750b7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e661b5f30566d1f159f060c264849c7ae4772f1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e661b5f30566d1f159f060c264849c7ae4772f1.hip new file mode 100644 index 00000000000..644f756253b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e661b5f30566d1f159f060c264849c7ae4772f1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ebacd06455ab20eba78b389462946716b5819f6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ebacd06455ab20eba78b389462946716b5819f6.hip new file mode 100644 index 00000000000..59ec37be84d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ebacd06455ab20eba78b389462946716b5819f6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ef309b923172f4c0fb38d9b9f5325b33b4877c2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ef309b923172f4c0fb38d9b9f5325b33b4877c2.hip new file mode 100644 index 00000000000..8a1b12f59bd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ef309b923172f4c0fb38d9b9f5325b33b4877c2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ef9b9413697d6f4573c6605bff6f58d027c5016.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ef9b9413697d6f4573c6605bff6f58d027c5016.hip new file mode 100644 index 00000000000..8182111c4a4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ef9b9413697d6f4573c6605bff6f58d027c5016.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0efdaa9266a5a464009297dc59db92504f8bf1a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0efdaa9266a5a464009297dc59db92504f8bf1a3.hip new file mode 100644 index 00000000000..e6cf970a99a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0efdaa9266a5a464009297dc59db92504f8bf1a3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0f0c699d9c3b0ed62097e38ba05e40e815cf474e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0f0c699d9c3b0ed62097e38ba05e40e815cf474e.hip new file mode 100644 index 00000000000..d6cfa8a5092 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0f0c699d9c3b0ed62097e38ba05e40e815cf474e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0f588dcb2ef86677ebf84e406eb802e9921d1f1e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0f588dcb2ef86677ebf84e406eb802e9921d1f1e.hip new file mode 100644 index 00000000000..ccfa82d52c1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0f588dcb2ef86677ebf84e406eb802e9921d1f1e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fbb0bef3b388867e75d7a8a187b8b4b650a42ae.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fbb0bef3b388867e75d7a8a187b8b4b650a42ae.hip new file mode 100644 index 00000000000..05fde574742 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fbb0bef3b388867e75d7a8a187b8b4b650a42ae.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fbddf533661642d84bf5a16149692d5a892182a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fbddf533661642d84bf5a16149692d5a892182a.hip new file mode 100644 index 00000000000..4b4b70a80f2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fbddf533661642d84bf5a16149692d5a892182a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fcb7492feb79e27e0bda73e57ef7dab410e2bb6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fcb7492feb79e27e0bda73e57ef7dab410e2bb6.hip new file mode 100644 index 00000000000..3ee6242ee00 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fcb7492feb79e27e0bda73e57ef7dab410e2bb6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fd4068ea93fcf4df463e3bf3a6898d23b65da7f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fd4068ea93fcf4df463e3bf3a6898d23b65da7f.hip new file mode 100644 index 00000000000..3e609731a5d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fd4068ea93fcf4df463e3bf3a6898d23b65da7f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_103186dbad604763008e0204a1ea90baecef8877.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_103186dbad604763008e0204a1ea90baecef8877.hip new file mode 100644 index 00000000000..0380a534950 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_103186dbad604763008e0204a1ea90baecef8877.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1037f1bc50c4a65dac09ba56b701256b701c4322.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1037f1bc50c4a65dac09ba56b701256b701c4322.hip new file mode 100644 index 00000000000..24ece6c55a7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1037f1bc50c4a65dac09ba56b701256b701c4322.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
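Each generated instance above follows the same skeleton: a trait alias encoding the tile configuration (dtype, block/warp tile sequences, mask, dropout, bias and padding flags), a kernel type built from that trait, and explicit specializations of fmha_bwd_dq_dk_dv_ / fmha_bwd_dq_dk_dv_oneshot_ / fmha_bwd_dq_dk_dv_get_name_ (or fmha_fwd_ for the forward files) whose bodies are compiled only when targeting gfx90a/gfx942 and otherwise reduce to stubs. A minimal, self-contained analogy of that trait-specialization dispatch is sketched below; the names toy_args, toy_trait_hdim128_bf16 and run_ are illustrative only and are not part of ck_tile or the generated sources.

// Self-contained sketch of the dispatch pattern used by the generated files:
// one distinct trait type per tile configuration, one explicit specialization
// per trait, and an arch guard so non-target builds compile to a stub.
// All names here are hypothetical; this is not the ck_tile API.
#include <iostream>

struct toy_args { int seqlen; };   // stand-in for fmha_bwd_args

// Primary template is only declared; every instance the generator emits is an
// explicit specialization, so an unsupported configuration fails at link time.
template <typename Trait>
float run_(const toy_args& a);

// One "generated" configuration, analogous to a dq_dk_dv_trait_0 alias above.
struct toy_trait_hdim128_bf16 {};

template <>
float run_<toy_trait_hdim128_bf16>(const toy_args& a)
{
#if (defined(__gfx90a__) || defined(__gfx942__))
    // A real instance would build kernel arguments and launch here.
    std::cout << "launching hdim128/bf16 kernel for seqlen " << a.seqlen << "\n";
    return 1.0f;   // stands in for the measured kernel time
#else
    (void)a;       // unused when the instance is not built for this arch
    return 0.0f;   // stub on architectures the instance was not built for
#endif
}

int main()
{
    toy_args a{4096};
    std::cout << "elapsed: " << run_<toy_trait_hdim128_bf16>(a) << "\n";
    return 0;
}

On a host-only or non-matching build the guard falls through to the stub branch, which is what allows every generated translation unit to compile regardless of which gfx arch is actually targeted.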
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10a055e5c3d6a953d470db5dc21449766248058a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10a055e5c3d6a953d470db5dc21449766248058a.hip new file mode 100644 index 00000000000..1cd4a3a3f7f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10a055e5c3d6a953d470db5dc21449766248058a.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + true, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10c24f1f9009e46afa3a59193784cc2575f79056.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10c24f1f9009e46afa3a59193784cc2575f79056.hip new file mode 100644 index 00000000000..1a9606077e1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10c24f1f9009e46afa3a59193784cc2575f79056.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10ceed95b0a0a01f844678717c88e0426fb503fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10ceed95b0a0a01f844678717c88e0426fb503fd.hip new file mode 100644 index 00000000000..ba338ae6a36 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10ceed95b0a0a01f844678717c88e0426fb503fd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1132b11429034d96d82c82dbfdb69e460ad8a564.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1132b11429034d96d82c82dbfdb69e460ad8a564.hip new file mode 100644 index 00000000000..10220b3ed4d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1132b11429034d96d82c82dbfdb69e460ad8a564.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_11e7df31541c3aa919e9825ad7dc4432f9a03c0c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_11e7df31541c3aa919e9825ad7dc4432f9a03c0c.hip new file mode 100644 index 00000000000..6bb1fa39154 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_11e7df31541c3aa919e9825ad7dc4432f9a03c0c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_11ff174ff2175e9ec22ac3a0fa59dd7713b79643.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_11ff174ff2175e9ec22ac3a0fa59dd7713b79643.hip new file mode 100644 index 00000000000..876953de89e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_11ff174ff2175e9ec22ac3a0fa59dd7713b79643.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1211733062ed30b876f1d63bffa642d77e258dd6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1211733062ed30b876f1d63bffa642d77e258dd6.hip new file mode 100644 index 00000000000..5b0886c2e10 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1211733062ed30b876f1d63bffa642d77e258dd6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12207f4b6e7fac27d6c16493a5373f448a2aaae8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12207f4b6e7fac27d6c16493a5373f448a2aaae8.hip new file mode 100644 index 00000000000..91425485ae5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12207f4b6e7fac27d6c16493a5373f448a2aaae8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1241814f76107d74ed069ecec99a248676487eee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1241814f76107d74ed069ecec99a248676487eee.hip new file mode 100644 index 00000000000..7f0d93f40f7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1241814f76107d74ed069ecec99a248676487eee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12d5c8a4988efe60ef7943ecd73e18a28a736583.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12d5c8a4988efe60ef7943ecd73e18a28a736583.hip new file mode 100644 index 00000000000..0948511ba53 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12d5c8a4988efe60ef7943ecd73e18a28a736583.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12d60c8abecb3bc9b84b0ea7851628ab17d8b0b3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12d60c8abecb3bc9b84b0ea7851628ab17d8b0b3.hip new file mode 100644 index 00000000000..996900f7009 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12d60c8abecb3bc9b84b0ea7851628ab17d8b0b3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS 
AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_131691f01cc7f29affb88152dd48c7a484315dcd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_131691f01cc7f29affb88152dd48c7a484315dcd.hip new file mode 100644 index 00000000000..9c4507d8322 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_131691f01cc7f29affb88152dd48c7a484315dcd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_131c1fdc4206bb952b2fea675f24e3b09f605eef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_131c1fdc4206bb952b2fea675f24e3b09f605eef.hip new file mode 100644 index 00000000000..8129f0f62d6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_131c1fdc4206bb952b2fea675f24e3b09f605eef.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_133c51948cf8584900807998da14d788039f53b9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_133c51948cf8584900807998da14d788039f53b9.hip new file mode 100644 index 00000000000..07dfd3b3131 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_133c51948cf8584900807998da14d788039f53b9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_135ea67de101135ed5fe04f5cab1ec1d7b3714bb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_135ea67de101135ed5fe04f5cab1ec1d7b3714bb.hip new file mode 100644 index 00000000000..30a8f33ad89 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_135ea67de101135ed5fe04f5cab1ec1d7b3714bb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; 
+#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_137fa6780d9e6bde10aec10a875c039fdbbc652e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_137fa6780d9e6bde10aec10a875c039fdbbc652e.hip new file mode 100644 index 00000000000..1eb2afd6ab7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_137fa6780d9e6bde10aec10a875c039fdbbc652e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1386cd75411e61a8dbbaf2b916e62f4f5f99104f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1386cd75411e61a8dbbaf2b916e62f4f5f99104f.hip new file mode 100644 index 00000000000..86e10c62ce9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1386cd75411e61a8dbbaf2b916e62f4f5f99104f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_13d5f2ec83b3331654e37ea0b44d88cd98abaa37.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_13d5f2ec83b3331654e37ea0b44d88cd98abaa37.hip new file mode 100644 index 00000000000..cb483891a41 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_13d5f2ec83b3331654e37ea0b44d88cd98abaa37.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_13f747525ad31e76c88774fb2208e470da9c2310.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_13f747525ad31e76c88774fb2208e470da9c2310.hip new file mode 100644 index 00000000000..cc7e1a08d5d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_13f747525ad31e76c88774fb2208e470da9c2310.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14221590b90c48d3cf259fb4e834ccfaf7f3209b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14221590b90c48d3cf259fb4e834ccfaf7f3209b.hip new file mode 100644 index 00000000000..032c888c89e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14221590b90c48d3cf259fb4e834ccfaf7f3209b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else 
+ return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_144f19363ef26efd36f0436cfa9f84f181a8824c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_144f19363ef26efd36f0436cfa9f84f181a8824c.hip new file mode 100644 index 00000000000..31d8a42c2d1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_144f19363ef26efd36f0436cfa9f84f181a8824c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_146eb8c40e3146e06936f3141b2c4d92a578ddec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_146eb8c40e3146e06936f3141b2c4d92a578ddec.hip new file mode 100644 index 00000000000..47e94e66edd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_146eb8c40e3146e06936f3141b2c4d92a578ddec.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, 
blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14baaaf1e90a075ab802c6e7d97c4b1605c8bd72.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14baaaf1e90a075ab802c6e7d97c4b1605c8bd72.hip new file mode 100644 index 00000000000..bfd55686121 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14baaaf1e90a075ab802c6e7d97c4b1605c8bd72.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14c4ebd1792c781d219bd21b691b575f64635730.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14c4ebd1792c781d219bd21b691b575f64635730.hip new file mode 100644 index 00000000000..b8d8881a17d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14c4ebd1792c781d219bd21b691b575f64635730.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14d11aad7b666f500f68b264a2fcca6dfc5f1a05.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14d11aad7b666f500f68b264a2fcca6dfc5f1a05.hip new file mode 100644 index 00000000000..7d5612962a1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14d11aad7b666f500f68b264a2fcca6dfc5f1a05.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14d4630876785655bd4950566e81ae0b645c0d3c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14d4630876785655bd4950566e81ae0b645c0d3c.hip new file mode 100644 index 00000000000..f8ac18f9892 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14d4630876785655bd4950566e81ae0b645c0d3c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14f77aeeafe4b28f314fde5ebccfd2a554872781.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14f77aeeafe4b28f314fde5ebccfd2a554872781.hip new file mode 100644 index 00000000000..39a73952876 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14f77aeeafe4b28f314fde5ebccfd2a554872781.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14fea611f3c253aebf726af3e5fdb7e63e18e13a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14fea611f3c253aebf726af3e5fdb7e63e18e13a.hip new file mode 100644 index 00000000000..1c9a01c0701 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14fea611f3c253aebf726af3e5fdb7e63e18e13a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_151a4425b411596c46c7032f6b83d3152a0e0cd4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_151a4425b411596c46c7032f6b83d3152a0e0cd4.hip new file mode 100644 index 00000000000..61478645543 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_151a4425b411596c46c7032f6b83d3152a0e0cd4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_153e897098539c3466da9d7a37234daf16476277.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_153e897098539c3466da9d7a37234daf16476277.hip new file mode 100644 index 00000000000..5d0bf7e6ec0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_153e897098539c3466da9d7a37234daf16476277.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1552dc38d26f6badb7a9bcb5ce9124d54cc45ed3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1552dc38d26f6badb7a9bcb5ce9124d54cc45ed3.hip new file mode 100644 index 00000000000..d90d53b8a77 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1552dc38d26f6badb7a9bcb5ce9124d54cc45ed3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_155bafb551768855c8c01faa63e44764ebe6c110.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_155bafb551768855c8c01faa63e44764ebe6c110.hip new file mode 100644 index 00000000000..2f6d56c17af --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_155bafb551768855c8c01faa63e44764ebe6c110.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_155c3549d067464d186a99b8205317cc000d4898.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_155c3549d067464d186a99b8205317cc000d4898.hip new file mode 100644 index 00000000000..eb7e1a33928 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_155c3549d067464d186a99b8205317cc000d4898.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1573e3d855d28c54af612ab950b081302891d56d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1573e3d855d28c54af612ab950b081302891d56d.hip new file mode 100644 index 00000000000..abd04eb76ca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1573e3d855d28c54af612ab950b081302891d56d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_157768cd725813f8111d265cfdfea7f42034e5e9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_157768cd725813f8111d265cfdfea7f42034e5e9.hip new file mode 100644 index 00000000000..4e2fc118bc2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_157768cd725813f8111d265cfdfea7f42034e5e9.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_157b89d8d625b8244b5cceaa4d3e5fc5a09c8989.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_157b89d8d625b8244b5cceaa4d3e5fc5a09c8989.hip new file mode 100644 index 00000000000..0c4b0b4ad5b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_157b89d8d625b8244b5cceaa4d3e5fc5a09c8989.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_158d5ce564c3ae1eefb54e3d41dde2604560ef4a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_158d5ce564c3ae1eefb54e3d41dde2604560ef4a.hip new file mode 100644 index 00000000000..793d43cad59 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_158d5ce564c3ae1eefb54e3d41dde2604560ef4a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_159ee1f1b44d1a8fbaead65d8449413bb616d15e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_159ee1f1b44d1a8fbaead65d8449413bb616d15e.hip new file mode 100644 index 00000000000..6006941b699 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_159ee1f1b44d1a8fbaead65d8449413bb616d15e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15b255dde1a9d915e582ee2a83de7d83190c6a24.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15b255dde1a9d915e582ee2a83de7d83190c6a24.hip new file mode 100644 index 00000000000..dc4ea63a6b6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15b255dde1a9d915e582ee2a83de7d83190c6a24.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15cf7068183421b141ed5d6e7fe902d06b6492a1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15cf7068183421b141ed5d6e7fe902d06b6492a1.hip new file mode 100644 index 00000000000..260c3559a59 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15cf7068183421b141ed5d6e7fe902d06b6492a1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15dc02ea7e0908cf0bd48034f5a49debfaa36219.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15dc02ea7e0908cf0bd48034f5a49debfaa36219.hip new file mode 100644 index 00000000000..46fb46ef11a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15dc02ea7e0908cf0bd48034f5a49debfaa36219.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15e8e1ab8c63db96843054bb7a98d708ae6a9c44.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15e8e1ab8c63db96843054bb7a98d708ae6a9c44.hip new file mode 100644 index 00000000000..91b4d887de6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15e8e1ab8c63db96843054bb7a98d708ae6a9c44.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15fe3e8f4add16a088fe44458353fa7c0c4f9658.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15fe3e8f4add16a088fe44458353fa7c0c4f9658.hip new file mode 100644 index 00000000000..65fac8b48c5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15fe3e8f4add16a088fe44458353fa7c0c4f9658.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_16047b5544acef40e39932672cac6f562e200948.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_16047b5544acef40e39932672cac6f562e200948.hip new file mode 100644 index 00000000000..85f56bdca7a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_16047b5544acef40e39932672cac6f562e200948.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1621507cf219fe608715d4e5bb6e5764022e2d61.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1621507cf219fe608715d4e5bb6e5764022e2d61.hip new file mode 100644 index 00000000000..a80c66dfaae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1621507cf219fe608715d4e5bb6e5764022e2d61.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_162b0dfbe3f615b1d164290799b2457437a0044b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_162b0dfbe3f615b1d164290799b2457437a0044b.hip new file mode 100644 index 00000000000..8ed77fdae46 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_162b0dfbe3f615b1d164290799b2457437a0044b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_164a947a6c2ba83a5b1cb7074aee0bdac6c9c64e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_164a947a6c2ba83a5b1cb7074aee0bdac6c9c64e.hip new file mode 100644 index 00000000000..55dc832eb93 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_164a947a6c2ba83a5b1cb7074aee0bdac6c9c64e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_165dfb45658df8f1ae8dc0738ac9614740f2576c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_165dfb45658df8f1ae8dc0738ac9614740f2576c.hip new file mode 100644 index 00000000000..68ef183d699 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_165dfb45658df8f1ae8dc0738ac9614740f2576c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_167f5328b035ed59a6f05dfee31edd704c4b07ee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_167f5328b035ed59a6f05dfee31edd704c4b07ee.hip new file mode 100644 index 00000000000..d3adafe0b28 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_167f5328b035ed59a6f05dfee31edd704c4b07ee.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1687ddf65ce4ed2997583e20fee9f201e86633b3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1687ddf65ce4ed2997583e20fee9f201e86633b3.hip new file mode 100644 index 00000000000..4f5830b7d44 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1687ddf65ce4ed2997583e20fee9f201e86633b3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_16f94f5c65c37624f5458c165daf83517d9e3c81.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_16f94f5c65c37624f5458c165daf83517d9e3c81.hip new file mode 100644 index 00000000000..86c7f2cba0e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_16f94f5c65c37624f5458c165daf83517d9e3c81.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_173c44dd85077e6b12dd06fdcf6b11ba349e1866.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_173c44dd85077e6b12dd06fdcf6b11ba349e1866.hip new file mode 100644 index 00000000000..919a3ac828c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_173c44dd85077e6b12dd06fdcf6b11ba349e1866.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_17b9b96edda151072215502cc2b606bf1f6f0b03.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_17b9b96edda151072215502cc2b606bf1f6f0b03.hip new file mode 100644 index 00000000000..d41b5dcef34 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_17b9b96edda151072215502cc2b606bf1f6f0b03.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1847fef2c06ea581b0ab31af1cb0556c572696ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1847fef2c06ea581b0ab31af1cb0556c572696ad.hip new file mode 100644 index 00000000000..d9bc327995b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1847fef2c06ea581b0ab31af1cb0556c572696ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_187963e1969301abfa61d06afc97faea2bb4efb1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_187963e1969301abfa61d06afc97faea2bb4efb1.hip new file mode 100644 index 00000000000..da9c2fe71e2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_187963e1969301abfa61d06afc97faea2bb4efb1.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1886d4bf54b3a4a9e093360998b2059b3c03d072.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1886d4bf54b3a4a9e093360998b2059b3c03d072.hip new file mode 
100644 index 00000000000..a0ac88f518c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1886d4bf54b3a4a9e093360998b2059b3c03d072.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = 
fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_188a70d526394e254274df95de0727850820326c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_188a70d526394e254274df95de0727850820326c.hip new file mode 100644 index 00000000000..62ef602df6e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_188a70d526394e254274df95de0727850820326c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1899e28aff2fb168cdc3af7132dd7fd09c2e1ced.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1899e28aff2fb168cdc3af7132dd7fd09c2e1ced.hip new file mode 100644 index 00000000000..bdc5cbba2ae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1899e28aff2fb168cdc3af7132dd7fd09c2e1ced.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void 
fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18a4d71b31c451a50df7996e3db864bc3c3882ed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18a4d71b31c451a50df7996e3db864bc3c3882ed.hip new file mode 100644 index 00000000000..997da7f6a12 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18a4d71b31c451a50df7996e3db864bc3c3882ed.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18b92b4e249195ac3e0c74d246585a4c9e0992fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18b92b4e249195ac3e0c74d246585a4c9e0992fd.hip new file mode 100644 index 00000000000..304a37b4489 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18b92b4e249195ac3e0c74d246585a4c9e0992fd.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18ed7195a9443c84956c3f32839cb3ab9056bdfc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18ed7195a9443c84956c3f32839cb3ab9056bdfc.hip new file mode 100644 index 00000000000..179dbf59798 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18ed7195a9443c84956c3f32839cb3ab9056bdfc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1914250fce818584291c69a5f058a58cfbd83df9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1914250fce818584291c69a5f058a58cfbd83df9.hip new file mode 100644 index 00000000000..ffab9800ad5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1914250fce818584291c69a5f058a58cfbd83df9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_193699a5daa14ca2def07489e0b563149bc403f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_193699a5daa14ca2def07489e0b563149bc403f8.hip new file mode 100644 index 00000000000..996f9773c4c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_193699a5daa14ca2def07489e0b563149bc403f8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19af6a7f9e5020e8d0f0ca0f6258001f6ce592c1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19af6a7f9e5020e8d0f0ca0f6258001f6ce592c1.hip new file mode 100644 index 00000000000..8eee4ffa35d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19af6a7f9e5020e8d0f0ca0f6258001f6ce592c1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19cd9f7b08cec83736605af63d9fcaf463a1aea4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19cd9f7b08cec83736605af63d9fcaf463a1aea4.hip new file mode 100644 index 00000000000..b353542ec8c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19cd9f7b08cec83736605af63d9fcaf463a1aea4.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19df4e13108e043361e9528b71df56f04f696a0c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19df4e13108e043361e9528b71df56f04f696a0c.hip new file mode 100644 index 00000000000..69ecf815ab6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19df4e13108e043361e9528b71df56f04f696a0c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a11dd5ebb989503a1c182684e7f247e2f8cd9c2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a11dd5ebb989503a1c182684e7f247e2f8cd9c2.hip new file mode 100644 index 00000000000..d03aa0275bb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a11dd5ebb989503a1c182684e7f247e2f8cd9c2.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a236be9da05a07d11cd28034d90cdf89941a172.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a236be9da05a07d11cd28034d90cdf89941a172.hip new file mode 100644 index 00000000000..765925ad32e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a236be9da05a07d11cd28034d90cdf89941a172.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a5e18f6333ed2cce509f07cb8bd5868951d66a0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a5e18f6333ed2cce509f07cb8bd5868951d66a0.hip new file mode 100644 index 00000000000..d3ceedc3ba2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a5e18f6333ed2cce509f07cb8bd5868951d66a0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a6785392af35e27d6697b584cb6f17a766d3fee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a6785392af35e27d6697b584cb6f17a766d3fee.hip new file mode 100644 index 00000000000..45412a0172d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a6785392af35e27d6697b584cb6f17a766d3fee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a6bc2762b95d550485aa720edaf71138d94cd07.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a6bc2762b95d550485aa720edaf71138d94cd07.hip new file mode 100644 index 00000000000..9d1782fd4f6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a6bc2762b95d550485aa720edaf71138d94cd07.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a8da3e6ab050262b659c801ccf9a14787d7f176.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a8da3e6ab050262b659c801ccf9a14787d7f176.hip new file mode 100644 index 00000000000..81ea5cb5728 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a8da3e6ab050262b659c801ccf9a14787d7f176.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a96f0ac76f117e66eba97cb990c2350561ec2ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a96f0ac76f117e66eba97cb990c2350561ec2ab.hip new file mode 100644 index 00000000000..07305c4de11 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a96f0ac76f117e66eba97cb990c2350561ec2ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a98bcbe900f8c141136d18c114b02fffbe8bca1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a98bcbe900f8c141136d18c114b02fffbe8bca1.hip new file mode 100644 index 00000000000..084f97abecc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a98bcbe900f8c141136d18c114b02fffbe8bca1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a99b2625adffa8215276bb88fc65bae944b846b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a99b2625adffa8215276bb88fc65bae944b846b.hip new file mode 100644 index 00000000000..6c9b49e2ebd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a99b2625adffa8215276bb88fc65bae944b846b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1acf2f892742b1d236d2b31a8185c6869126adad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1acf2f892742b1d236d2b31a8185c6869126adad.hip new file mode 100644 index 00000000000..08cb9227ae1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1acf2f892742b1d236d2b31a8185c6869126adad.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1b3e7c8969027d3316875f33dc50fe022e05ce37.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1b3e7c8969027d3316875f33dc50fe022e05ce37.hip new file mode 100644 index 00000000000..f32d6823f03 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1b3e7c8969027d3316875f33dc50fe022e05ce37.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1be43f8b629e7039f57b95866d5777273377470d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1be43f8b629e7039f57b95866d5777273377470d.hip new file mode 100644 index 00000000000..de2671c488a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1be43f8b629e7039f57b95866d5777273377470d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
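Editor's note: the autogenerated files in this directory all follow the same shape. Each file pins one tile/trait configuration (block tile, warp layout, dtype, pipeline, bias/dropout flags), explicitly specializes a templated launcher for exactly that configuration, and guards the real kernel launch behind a gfx90a/gfx942 architecture check so every other target compiles to a stub that returns 0.0. The sketch below is a minimal, host-only illustration of that trait-specialization dispatch pattern; it deliberately does not use ck_tile's real API, and stream_config, fmha_bwd_args, dq_dk_dv_traits, and the launcher name are simplified stand-ins, not the generated code's actual types.

// Minimal sketch (assumed/simplified names, not ck_tile's real API): one explicit
// specialization per tile configuration, selected at compile time by a trait tag.
#include <iostream>

struct stream_config { int log_level_ = 0; };   // stand-in for ck_tile::stream_config
struct fmha_bwd_args { int batch = 1; };        // stand-in for the real argument struct

// A trait tag encodes one tile configuration (head dim, dtype, dropout, ...).
template <int HeadDim, bool IsBf16, bool HasDropout>
struct dq_dk_dv_traits {};

// The primary template is only declared; each generated file supplies one
// explicit specialization for the configuration it instantiates.
template <typename Trait>
float fmha_bwd_dq_dk_dv(const stream_config& s, fmha_bwd_args a);

// Example specialization, analogous to one generated file.
template <>
float fmha_bwd_dq_dk_dv<dq_dk_dv_traits<64, true, false>>(const stream_config& s,
                                                          fmha_bwd_args /*a*/)
{
    if (s.log_level_ > 0)
        std::cout << "launching hd64 bf16 backward kernel\n";
#if defined(__gfx90a__) || defined(__gfx942__)
    // On supported architectures the real code builds kargs/grids and launches here.
    return 1.0f;  // placeholder for the measured kernel time
#else
    // On any other target this specialization compiles to a no-op stub.
    return 0.0f;
#endif
}

int main()
{
    stream_config s{1};
    fmha_bwd_args a{};
    // A host-side dispatcher would pick the trait matching the runtime problem
    // shape and call the corresponding specialization, as sketched here.
    float t = fmha_bwd_dq_dk_dv<dq_dk_dv_traits<64, true, false>>(s, a);
    std::cout << "elapsed: " << t << "\n";
    return 0;
}

Splitting each configuration into its own translation unit keeps per-file compile times and register pressure manageable and lets the build prune configurations per architecture, which is why the diff adds so many near-identical files.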
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1be746990a2032f0363ad9f9112cc994983f4706.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1be746990a2032f0363ad9f9112cc994983f4706.hip new file mode 100644 index 00000000000..7605150e831 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1be746990a2032f0363ad9f9112cc994983f4706.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1bf767e7104cfc8322f26df35907fbf04b8948f3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1bf767e7104cfc8322f26df35907fbf04b8948f3.hip new file mode 100644 index 00000000000..f37ebaa9855 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1bf767e7104cfc8322f26df35907fbf04b8948f3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c1b0f85e085dd0769c566fb16aafe5ab5952714.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c1b0f85e085dd0769c566fb16aafe5ab5952714.hip new file mode 100644 index 00000000000..5437e8aac26 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c1b0f85e085dd0769c566fb16aafe5ab5952714.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c2a2d78176e3f0a78e3ad78217e75a4430c0de5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c2a2d78176e3f0a78e3ad78217e75a4430c0de5.hip new file mode 100644 index 00000000000..903d320799e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c2a2d78176e3f0a78e3ad78217e75a4430c0de5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c65ba6dba01da9caa84ba89453b61d81376763f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c65ba6dba01da9caa84ba89453b61d81376763f.hip new file mode 100644 index 00000000000..cc965b6f707 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c65ba6dba01da9caa84ba89453b61d81376763f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1ca3f45d0be2d1119cccd0af042a3e8adeda2ed7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1ca3f45d0be2d1119cccd0af042a3e8adeda2ed7.hip new file mode 100644 index 00000000000..30608a330e5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1ca3f45d0be2d1119cccd0af042a3e8adeda2ed7.hip @@ -0,0 +1,1965 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){ + float r = -1; + if(t.data_type.compare("fp16") == 0){ + if (t.hdim_q <= 32 && t.hdim_v <= 32) { + if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && 
(t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, 
true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == 
false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, 
false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 
== 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && 
(a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, 
true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + + } + else if (t.hdim_q <= 64 && t.hdim_v <= 64) { + if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == 
true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 
64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && 
(t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, 
true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == 
true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && 
(t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, 
ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + + } + else if (t.hdim_q <= 128 && t.hdim_v <= 128) { + if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + 
else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = 
fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, 
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && 
(t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && 
(t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, 
true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } 
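(Reader's note on this generated hunk: the file is a pure dispatcher. Each `else if` compares the runtime fields of `fmha_fwd_traits` / `fmha_fwd_args` — group mode, V layout, mask type, bias type, LSE, dropout, fp8 quant, and the seqlen/hdim divisibility checks used for padding — against one supported combination, and on a match instantiates the corresponding compile-time `fmha_fwd_traits_<...>` and forwards to `fmha_fwd_<trait_>(s, a)`. The sketch below shows the same runtime-flags-to-template-parameters pattern in a much smaller form; every name in it (`RuntimeTraits`, `launch_kernel`, `dispatch_fwd`) is a hypothetical simplification for illustration and is not part of the CK API.)

// Illustrative only: a stripped-down version of the dispatch pattern above.
// The real generated code maps fmha_fwd_traits/fmha_fwd_args onto
// fmha_fwd_traits_<...> instantiations in exactly this if/else fashion.
#include <iostream>
#include <stdexcept>

struct RuntimeTraits {
    int  hdim_q;        // head dimension, known only at run time
    bool has_lse;       // whether a log-sum-exp output is requested
    bool has_dropout;   // whether attention dropout is active
};

// Each combination of compile-time parameters stands in for one
// pre-instantiated kernel (one "trait_" alias in the generated code).
template <int kHDim, bool kHasLse, bool kHasDropout>
float launch_kernel() {
    std::cout << "kernel<hdim=" << kHDim
              << ", lse=" << kHasLse
              << ", dropout=" << kHasDropout << ">\n";
    return 0.0f; // stands in for the value returned by the real launcher
}

// Runtime flags are folded into template arguments through an if/else chain,
// mirroring the generated fmha_fwd dispatcher.
float dispatch_fwd(const RuntimeTraits& t) {
    if (t.hdim_q <= 64) {
        if (t.has_lse && t.has_dropout)  return launch_kernel<64, true,  true >();
        if (t.has_lse && !t.has_dropout) return launch_kernel<64, true,  false>();
        if (!t.has_lse && t.has_dropout) return launch_kernel<64, false, true >();
        return launch_kernel<64, false, false>();
    } else if (t.hdim_q <= 128) {
        if (t.has_lse) return launch_kernel<128, true,  false>();
        return launch_kernel<128, false, false>();
    }
    throw std::runtime_error("unsupported head dimension");
}

int main() {
    dispatch_fwd({64, true, false});   // picks kernel<64, lse=1, dropout=0>
    dispatch_fwd({128, false, false}); // picks kernel<128, lse=0, dropout=0>
}

(The point of the long chain is that every supported configuration gets its own fully specialized kernel instantiation, so the per-launch overhead is only a handful of boolean comparisons; the cost is paid instead in compile time and binary size, which is why the CMake changes warn about building CK instances for multiple gfx archs.)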
+ else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + + } + else if (t.hdim_q <= 256 && t.hdim_v <= 256) { + if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + 
(true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == 
bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, 
ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && 
(t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 
256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else 
if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && 
(t.has_dropout == true) && (t.do_fp8_static_quant == false) &&
+ (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) {
+ using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>;
+ return fmha_fwd_<trait_>(s, a);
+ }
+ else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) &&
+ (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) {
+ using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>;
+ return fmha_fwd_<trait_>(s, a);
+ }
+
+ }
+
+ }
+ else if(t.data_type.compare("bf16") == 0){
+ if (t.hdim_q <= 32 && t.hdim_v <= 32) {
+ if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) &&
+ (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>;
+ return fmha_fwd_<trait_>(s, a);
+ }
+ else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) &&
+ (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>;
+ return fmha_fwd_<trait_>(s, a);
+ }
+ else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) &&
+ (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>;
+ return fmha_fwd_<trait_>(s, a);
+ }
+ else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) &&
+ (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<32,
ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == 
bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && 
a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else 
if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 
16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) 
&& (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad 
always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>;
+ return fmha_fwd_<trait_>(s, a);
+ }
+
+ }
+ else if (t.hdim_q <= 64 && t.hdim_v <= 64) {
+ if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) &&
+ (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>;
+ return fmha_fwd_<trait_>(s, a);
+ }
+ else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) &&
+ (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>;
+ return fmha_fwd_<trait_>(s, a);
+ }
+ else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) &&
+ (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>;
+ return fmha_fwd_<trait_>(s, a);
+ }
+ else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) &&
+ (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>;
+ return fmha_fwd_<trait_>(s, a);
+ }
+ else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) &&
+ (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>;
+ return fmha_fwd_<trait_>(s, a);
+ }
+ else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant
== false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return 
fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, 
ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == 
bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode 
spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) &&
+ (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>;
+ return fmha_fwd_(s, a);
+ }
+ else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) &&
+ (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>;
+ return fmha_fwd_(s, a);
+ }
+ else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) &&
+ (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>;
+ return fmha_fwd_(s, a);
+ }
+ else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) &&
+ (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>;
+ return fmha_fwd_(s, a);
+ }
+
+ }
+ else if (t.hdim_q <= 128 && t.hdim_v <= 128) {
+ if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) &&
+ (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>;
+ return fmha_fwd_(s, a);
+ }
+ else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) &&
+ (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v %
8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && 
(t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 
128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && 
(t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == 
false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && 
(a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>;
+ return fmha_fwd_(s, a);
+ }
+ else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) &&
+ (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>;
+ return fmha_fwd_(s, a);
+ }
+ else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) &&
+ (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) {
+ using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>;
+ return fmha_fwd_(s, a);
+ }
+
+ }
+ else if (t.hdim_q <= 256 && t.hdim_v <= 256) {
+ if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) &&
+ (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) {
+ using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>;
+ return fmha_fwd_(s, a);
+ }
+ else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) &&
+ (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) {
+ using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>;
+ return fmha_fwd_(s, a);
+ }
+ else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) &&
+ (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) {
+ using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask,
ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == 
bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 
128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && 
(t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 
== 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return 
fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == 
false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true 
/*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, 
ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + + } + + } + + return r; +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1cbf88db44aa5f884438288a325270d29c7a04b6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1cbf88db44aa5f884438288a325270d29c7a04b6.hip new file mode 100644 index 00000000000..f89fcd9026d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1cbf88db44aa5f884438288a325270d29c7a04b6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1cc459e57bfed5ec7f40ea4a4dd9f72f3ad7a709.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1cc459e57bfed5ec7f40ea4a4dd9f72f3ad7a709.hip new file mode 100644 index 00000000000..a7b6ac361d1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1cc459e57bfed5ec7f40ea4a4dd9f72f3ad7a709.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d02609fb803ea2697e2c2cef35e6f923d2578cf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d02609fb803ea2697e2c2cef35e6f923d2578cf.hip new file mode 100644 index 00000000000..26c89e01291 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d02609fb803ea2697e2c2cef35e6f923d2578cf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d0b822743e0205f60521d38d7c64f589fdf0f58.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d0b822743e0205f60521d38d7c64f589fdf0f58.hip new file mode 100644 index 00000000000..a8b77d2a4e2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d0b822743e0205f60521d38d7c64f589fdf0f58.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d21263e16dafe79b9fe2f998847296e575c14e7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d21263e16dafe79b9fe2f998847296e575c14e7.hip new file mode 100644 index 00000000000..5f35d549931 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d21263e16dafe79b9fe2f998847296e575c14e7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d3ef3d5ded0dfe2a0bafb52ea8f841658db35fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d3ef3d5ded0dfe2a0bafb52ea8f841658db35fd.hip new file mode 100644 index 00000000000..132239fd081 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d3ef3d5ded0dfe2a0bafb52ea8f841658db35fd.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType, + typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d498e418ebbf33bed58b4074d1edf3d9bdd07c5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d498e418ebbf33bed58b4074d1edf3d9bdd07c5.hip new file mode 100644 index 00000000000..fa10992280f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d498e418ebbf33bed58b4074d1edf3d9bdd07c5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1da23de9604b5d98fe02529075bad995954c12ca.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1da23de9604b5d98fe02529075bad995954c12ca.hip new file mode 100644 index 00000000000..9966b0db808 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1da23de9604b5d98fe02529075bad995954c12ca.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1db03461737f1e359f389a8d297476f9b60faabd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1db03461737f1e359f389a8d297476f9b60faabd.hip new file mode 100644 index 00000000000..74dcf07cfb1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1db03461737f1e359f389a8d297476f9b60faabd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1dc6e599144a093203fd7f92ac6d3c2cd7180d49.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1dc6e599144a093203fd7f92ac6d3c2cd7180d49.hip new file mode 100644 index 00000000000..8840038e578 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1dc6e599144a093203fd7f92ac6d3c2cd7180d49.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1de2f97d49f015b9af0b186801e939c6f357a0c4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1de2f97d49f015b9af0b186801e939c6f357a0c4.hip new file mode 100644 index 00000000000..ff58791e85f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1de2f97d49f015b9af0b186801e939c6f357a0c4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1df893ee660d37fba7eaca452ae65b3e45a73087.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1df893ee660d37fba7eaca452ae65b3e45a73087.hip new file mode 100644 index 00000000000..696d403eda6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1df893ee660d37fba7eaca452ae65b3e45a73087.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e22f2d99804198c61251b4629a3f18ed3dcd42e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e22f2d99804198c61251b4629a3f18ed3dcd42e.hip new file mode 100644 index 00000000000..8590d328eb8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e22f2d99804198c61251b4629a3f18ed3dcd42e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e33ce1fa113b221e5303b4093c2c4e748ce8298.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e33ce1fa113b221e5303b4093c2c4e748ce8298.hip new file mode 100644 index 00000000000..75d061fed3d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e33ce1fa113b221e5303b4093c2c4e748ce8298.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e42736d4f677a59a172bd6f162616a437696351.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e42736d4f677a59a172bd6f162616a437696351.hip new file mode 100644 index 00000000000..ef6b29cf938 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e42736d4f677a59a172bd6f162616a437696351.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e7d7888480b83c78833214b32e10f37a6e20301.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e7d7888480b83c78833214b32e10f37a6e20301.hip new file mode 100644 index 00000000000..9869c96ce0e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e7d7888480b83c78833214b32e10f37a6e20301.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e9130607a2d24cb0662a47e9cf12c6602143838.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e9130607a2d24cb0662a47e9cf12c6602143838.hip new file mode 100644 index 00000000000..6a5fef08dcb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e9130607a2d24cb0662a47e9cf12c6602143838.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e943fcc2e64c618fc1415b3f1a0db4d70aa8494.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e943fcc2e64c618fc1415b3f1a0db4d70aa8494.hip new file mode 100644 index 00000000000..9f0ab607590 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e943fcc2e64c618fc1415b3f1a0db4d70aa8494.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1edaf9d4270d2ac61c299320e06ba73f44730364.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1edaf9d4270d2ac61c299320e06ba73f44730364.hip new file mode 100644 index 00000000000..fdbfe7482c9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1edaf9d4270d2ac61c299320e06ba73f44730364.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else 
+ return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f0cad6ad5b172e51c569e84cd54a19b4eb0ed05.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f0cad6ad5b172e51c569e84cd54a19b4eb0ed05.hip new file mode 100644 index 00000000000..4c3c7d43b9e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f0cad6ad5b172e51c569e84cd54a19b4eb0ed05.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f13a6d0f8c798c0c4ba4ad202d081899fe081ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f13a6d0f8c798c0c4ba4ad202d081899fe081ab.hip new file mode 100644 index 00000000000..2982627aa80 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f13a6d0f8c798c0c4ba4ad202d081899fe081ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f6bc5faf18be193212217788d476ce6fd384bfb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f6bc5faf18be193212217788d476ce6fd384bfb.hip new file mode 100644 index 00000000000..25a407fd40f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f6bc5faf18be193212217788d476ce6fd384bfb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f7faa0b33a9aada86f032174afd40d18efa7715.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f7faa0b33a9aada86f032174afd40d18efa7715.hip new file mode 100644 index 00000000000..298a4e9b153 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f7faa0b33a9aada86f032174afd40d18efa7715.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f81f8cce0d77dec9f977b9eeb0778b70a13fa75.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f81f8cce0d77dec9f977b9eeb0778b70a13fa75.hip new file mode 100644 index 00000000000..23f4ba8fab5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f81f8cce0d77dec9f977b9eeb0778b70a13fa75.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fcdcb750f382fc7828a9886585f50efbe5be735.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fcdcb750f382fc7828a9886585f50efbe5be735.hip new file mode 100644 index 00000000000..790586338a1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fcdcb750f382fc7828a9886585f50efbe5be735.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fd9fa7c2e13d0bad5fddb2b5a316bbc09d397ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fd9fa7c2e13d0bad5fddb2b5a316bbc09d397ea.hip new file mode 100644 index 00000000000..7ee363221c2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fd9fa7c2e13d0bad5fddb2b5a316bbc09d397ea.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fda1c96568eab89a8f6498f8bb23c1223cdc7b0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fda1c96568eab89a8f6498f8bb23c1223cdc7b0.hip new file mode 100644 index 00000000000..cdc06572111 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fda1c96568eab89a8f6498f8bb23c1223cdc7b0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2005aca3520b171bb82d10ad70fef44f28c19776.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2005aca3520b171bb82d10ad70fef44f28c19776.hip new file mode 100644 index 00000000000..3cb13e21042 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2005aca3520b171bb82d10ad70fef44f28c19776.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_204a573ce6b7d2f90aede543939315561cc43177.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_204a573ce6b7d2f90aede543939315561cc43177.hip new file mode 100644 index 00000000000..038304e4237 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_204a573ce6b7d2f90aede543939315561cc43177.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20588bcac681a5d69f252d7523a3681a0c6b6181.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20588bcac681a5d69f252d7523a3681a0c6b6181.hip new file mode 100644 index 00000000000..5f9a6cd1607 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20588bcac681a5d69f252d7523a3681a0c6b6181.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2081430c92864c29bb9f409e7c27caee1de00749.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2081430c92864c29bb9f409e7c27caee1de00749.hip new file mode 100644 index 00000000000..a66950b1b4b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2081430c92864c29bb9f409e7c27caee1de00749.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20d5c3c86398f6ce55abc90db3e362dbf9f457f2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20d5c3c86398f6ce55abc90db3e362dbf9f457f2.hip new file mode 100644 index 00000000000..05e5cc3248c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20d5c3c86398f6ce55abc90db3e362dbf9f457f2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20f7ea0aabd069362ba4bbd66623cea5b6e1a6bd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20f7ea0aabd069362ba4bbd66623cea5b6e1a6bd.hip new file mode 100644 index 00000000000..e723a9c2bf4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20f7ea0aabd069362ba4bbd66623cea5b6e1a6bd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_210ef512b7862837f54acbc3b21e135a192647a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_210ef512b7862837f54acbc3b21e135a192647a3.hip new file mode 100644 index 00000000000..9fcad67cae6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_210ef512b7862837f54acbc3b21e135a192647a3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, 
blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2122c973581930ab7a4ebc90b3bf1cdaa229a87f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2122c973581930ab7a4ebc90b3bf1cdaa229a87f.hip new file mode 100644 index 00000000000..dab9ce44e65 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2122c973581930ab7a4ebc90b3bf1cdaa229a87f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21411df58165946bf02942b597d94de7dd856987.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21411df58165946bf02942b597d94de7dd856987.hip new file mode 100644 index 00000000000..4c82e0eed5b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21411df58165946bf02942b597d94de7dd856987.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_216806a4598c885e517e664fc8280c59ec3cbf11.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_216806a4598c885e517e664fc8280c59ec3cbf11.hip new file mode 100644 index 00000000000..9370d29e805 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_216806a4598c885e517e664fc8280c59ec3cbf11.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2173b7c710d418f44dc2b41bec5905024334eae5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2173b7c710d418f44dc2b41bec5905024334eae5.hip new file mode 100644 index 00000000000..80f87ac9f3a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2173b7c710d418f44dc2b41bec5905024334eae5.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2177d95cdf45f6fec95d1812f2ef183a75259e38.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2177d95cdf45f6fec95d1812f2ef183a75259e38.hip new file mode 100644 index 00000000000..720880cc4ab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2177d95cdf45f6fec95d1812f2ef183a75259e38.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21828c7d3f5574690f12f841c27f025206e6165b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21828c7d3f5574690f12f841c27f025206e6165b.hip new file mode 100644 index 00000000000..af05c411c24 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21828c7d3f5574690f12f841c27f025206e6165b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2184fba2eec5899bb40d49d4508196e6be1ec1b1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2184fba2eec5899bb40d49d4508196e6be1ec1b1.hip new file mode 100644 index 00000000000..3b30ec85c77 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2184fba2eec5899bb40d49d4508196e6be1ec1b1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21e235e31d6955393ac8e825bd69ead70687b7c8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21e235e31d6955393ac8e825bd69ead70687b7c8.hip new file mode 100644 index 00000000000..7e33230130a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21e235e31d6955393ac8e825bd69ead70687b7c8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21f860d42fdc2cc6bd743d53ba546e332c22fedf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21f860d42fdc2cc6bd743d53ba546e332c22fedf.hip new file mode 100644 index 00000000000..a3b06ae0a85 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21f860d42fdc2cc6bd743d53ba546e332c22fedf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22105635385fbfb5d2f330df83ba6747bcb27f6d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22105635385fbfb5d2f330df83ba6747bcb27f6d.hip new file mode 100644 index 00000000000..a859ad0a5f0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22105635385fbfb5d2f330df83ba6747bcb27f6d.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_224f9af5e5ca519b21b71a54acb49f50b4999c47.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_224f9af5e5ca519b21b71a54acb49f50b4999c47.hip new file mode 100644 index 00000000000..2fadab0c4c7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_224f9af5e5ca519b21b71a54acb49f50b4999c47.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22511de2592b6e350737e44865e1fed6496e3f32.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22511de2592b6e350737e44865e1fed6496e3f32.hip new file mode 100644 index 00000000000..d12086cc2a2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22511de2592b6e350737e44865e1fed6496e3f32.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22632f996eb63fbe4bc5748c5897b775087446a0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22632f996eb63fbe4bc5748c5897b775087446a0.hip new file mode 100644 index 00000000000..1412dad1cb3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22632f996eb63fbe4bc5748c5897b775087446a0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
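
Note on the pattern shared by these generated units: each file defines one concrete kernel type and then provides explicit specializations of the generic launchers (fmha_bwd_dq_dk_dv_, fmha_bwd_convert_dq_, fmha_fwd_ parameterized on the trait alias defined just above them) that launch the kernel through ck_tile::launch_kernel only when compiled for gfx90a or gfx942, and compile to a no-op returning 0.0 on any other target. The sketch below is a simplified, standalone illustration of that trait-tag dispatch; it deliberately avoids ck_tile so it compiles on its own, and names such as example_trait, ExampleKernel, and launcher_ are placeholders, not symbols from this patch.

// Simplified, self-contained illustration of the per-instance dispatch pattern
// used by the autogenerated files above and below. All names are illustrative;
// the real code builds kernel arguments from fmha_bwd_args / fmha_fwd_args and
// returns the result of ck_tile::launch_kernel(...).
#include <iostream>
#include <string>

struct example_trait {};   // stands in for a generated trait tag such as dq_dk_dv_trait_0

struct ExampleKernel {     // stands in for a concrete kernel instantiation
    static std::string GetName() { return "example_kernel"; }
};

// Generic launcher: declared once, specialized once per generated translation unit.
template <typename Trait>
float launcher_(int log_level);

template <>
float launcher_<example_trait>(int log_level)
{
    using k_ = ExampleKernel;
    if (log_level > 0)
        std::cout << ", " << k_::GetName() << std::flush;
#if defined(__gfx90a__) || defined(__gfx942__)
    // Supported GPU targets: the real code launches the ck_tile kernel here and
    // returns its timing result; this sketch just returns a dummy value.
    return 1.0f;
#else
    return 0.0f;           // every other target compiles to a no-op, as in the patch
#endif
}

int main()
{
    std::cout << launcher_<example_trait>(1) << "\n";
}
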
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_226662cf1c9900a4334d2cadcc5f5ac3ad355f05.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_226662cf1c9900a4334d2cadcc5f5ac3ad355f05.hip new file mode 100644 index 00000000000..f66e9193a3a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_226662cf1c9900a4334d2cadcc5f5ac3ad355f05.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2273457ac3be01cc1595a015a5f598f8290c77e4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2273457ac3be01cc1595a015a5f598f8290c77e4.hip new file mode 100644 index 00000000000..0cff307021e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2273457ac3be01cc1595a015a5f598f8290c77e4.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22a07ecf1a59f72ec6bef3e970d7f33cf54c5f44.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22a07ecf1a59f72ec6bef3e970d7f33cf54c5f44.hip new file mode 100644 index 00000000000..eb66c2a58e6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22a07ecf1a59f72ec6bef3e970d7f33cf54c5f44.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22c142d869ef940ca876c93033ad53b576ed34f2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22c142d869ef940ca876c93033ad53b576ed34f2.hip new file mode 100644 index 00000000000..14c3b0eb016 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22c142d869ef940ca876c93033ad53b576ed34f2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23047ea90076e3b0a3eb0586d49b9ee74ca6d279.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23047ea90076e3b0a3eb0586d49b9ee74ca6d279.hip new file mode 100644 index 00000000000..be2f64ccab6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23047ea90076e3b0a3eb0586d49b9ee74ca6d279.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_230861e81e5acc523fa680534eed757b7b4a4e1d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_230861e81e5acc523fa680534eed757b7b4a4e1d.hip new file mode 100644 index 00000000000..d1c5a3fa924 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_230861e81e5acc523fa680534eed757b7b4a4e1d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_232f61bf31dbb5de5d7039d5ff2338068a759b68.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_232f61bf31dbb5de5d7039d5ff2338068a759b68.hip new file mode 100644 index 00000000000..d26a92e4eca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_232f61bf31dbb5de5d7039d5ff2338068a759b68.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_233132e712eba8972ba444c604f89e01c5b84cc0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_233132e712eba8972ba444c604f89e01c5b84cc0.hip new file mode 100644 index 00000000000..6d900a17bfa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_233132e712eba8972ba444c604f89e01c5b84cc0.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_235bf652702c2976551778b9159e09188575c63c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_235bf652702c2976551778b9159e09188575c63c.hip new file mode 100644 index 00000000000..c42575bbb15 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_235bf652702c2976551778b9159e09188575c63c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_236b3eef02b904304348b9d35f715b639d63218f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_236b3eef02b904304348b9d35f715b639d63218f.hip new file mode 100644 index 00000000000..e78ae2da943 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_236b3eef02b904304348b9d35f715b639d63218f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_238e4c1ca112afec494fbe47a85b553302c43395.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_238e4c1ca112afec494fbe47a85b553302c43395.hip new file mode 100644 index 00000000000..a55761ed058 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_238e4c1ca112afec494fbe47a85b553302c43395.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23914c00690ac5c4f89cdbbaf00732ba66c5c0ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23914c00690ac5c4f89cdbbaf00732ba66c5c0ef.hip new file mode 100644 index 00000000000..f86fd21a3de --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23914c00690ac5c4f89cdbbaf00732ba66c5c0ef.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23c9b46da8774462de8c24e14b12df3ed596eb57.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23c9b46da8774462de8c24e14b12df3ed596eb57.hip new file mode 100644 index 00000000000..684f699a20a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23c9b46da8774462de8c24e14b12df3ed596eb57.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_242013527a0266ad479715ee3e6ae01c45de29d0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_242013527a0266ad479715ee3e6ae01c45de29d0.hip new file mode 100644 index 00000000000..ec3da5375c0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_242013527a0266ad479715ee3e6ae01c45de29d0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_24410fd9a4150c33186a2a365d06d8f6ea621c20.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_24410fd9a4150c33186a2a365d06d8f6ea621c20.hip new file mode 100644 index 00000000000..e57b6b2464d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_24410fd9a4150c33186a2a365d06d8f6ea621c20.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_245d90000b55ab8b6055b1934880fc6c4870b34b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_245d90000b55ab8b6055b1934880fc6c4870b34b.hip new file mode 100644 index 00000000000..2d7d691e854 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_245d90000b55ab8b6055b1934880fc6c4870b34b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_24643917fc970c043d1c80d8d4b17ec92deeb8a1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_24643917fc970c043d1c80d8d4b17ec92deeb8a1.hip new file mode 100644 index 00000000000..9dcaec3e6ec --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_24643917fc970c043d1c80d8d4b17ec92deeb8a1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_249668a3212cd00edaae871758be30a5a1fea589.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_249668a3212cd00edaae871758be30a5a1fea589.hip new file mode 100644 index 00000000000..c8befdad639 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_249668a3212cd00edaae871758be30a5a1fea589.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_249e6b93baae25dff97a0bc9145a8d328ed3f317.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_249e6b93baae25dff97a0bc9145a8d328ed3f317.hip new file mode 100644 index 00000000000..f385d543ec4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_249e6b93baae25dff97a0bc9145a8d328ed3f317.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2543da478310245e19e6c6a0d9ed7ad99540b3bc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2543da478310245e19e6c6a0d9ed7ad99540b3bc.hip new file mode 100644 index 00000000000..171cdac4f03 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2543da478310245e19e6c6a0d9ed7ad99540b3bc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_256ef175029a43e64164176d4eb212baf9d27bb9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_256ef175029a43e64164176d4eb212baf9d27bb9.hip new file mode 100644 index 00000000000..fbed56c47d9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_256ef175029a43e64164176d4eb212baf9d27bb9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_258d747083272ea657604ac84867ecea17bd65da.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_258d747083272ea657604ac84867ecea17bd65da.hip new file mode 100644 index 00000000000..89bc7dd7e2b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_258d747083272ea657604ac84867ecea17bd65da.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_25938733446b6c0dcd159719f08d04a9aa467967.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_25938733446b6c0dcd159719f08d04a9aa467967.hip new file mode 100644 index 00000000000..a4af048e43f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_25938733446b6c0dcd159719f08d04a9aa467967.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_25b3225da1e1842f83592971a1f62a0fe30aa9d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_25b3225da1e1842f83592971a1f62a0fe30aa9d3.hip new file mode 100644 index 00000000000..525dc28431d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_25b3225da1e1842f83592971a1f62a0fe30aa9d3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2660282ad39ef034fecbdb74acedfb48620b7dfd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2660282ad39ef034fecbdb74acedfb48620b7dfd.hip new file mode 100644 index 00000000000..f6d4f632d33 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2660282ad39ef034fecbdb74acedfb48620b7dfd.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26835ba70606c769e56d19dbfe74061361aa855e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26835ba70606c769e56d19dbfe74061361aa855e.hip new file mode 100644 index 00000000000..30ba2020a4e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26835ba70606c769e56d19dbfe74061361aa855e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2695783ae8f0034692efd6563f789ef03fd0f4f3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2695783ae8f0034692efd6563f789ef03fd0f4f3.hip new file mode 100644 index 00000000000..af5313e1cab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2695783ae8f0034692efd6563f789ef03fd0f4f3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26d77b228420a3ead919474ec9c6fb2800f86890.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26d77b228420a3ead919474ec9c6fb2800f86890.hip new file mode 100644 index 00000000000..4544758fe65 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26d77b228420a3ead919474ec9c6fb2800f86890.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS 
AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26ea90eb5a527434c1740933a1d2dd863eccf14c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26ea90eb5a527434c1740933a1d2dd863eccf14c.hip new file mode 100644 index 00000000000..9bfdb6aef82 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26ea90eb5a527434c1740933a1d2dd863eccf14c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26f90358e522d7bb7c76c3a2c6010f0f38788bb6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26f90358e522d7bb7c76c3a2c6010f0f38788bb6.hip new file mode 100644 index 00000000000..95c8658d805 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26f90358e522d7bb7c76c3a2c6010f0f38788bb6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = 
k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2703018e71d57d3266fc35e2e18a78faa3dd52ce.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2703018e71d57d3266fc35e2e18a78faa3dd52ce.hip new file mode 100644 index 00000000000..a94494a3e4e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2703018e71d57d3266fc35e2e18a78faa3dd52ce.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; 
+ auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_278639d44a4a8372a627a7c31e9527c8faa26f97.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_278639d44a4a8372a627a7c31e9527c8faa26f97.hip new file mode 100644 index 00000000000..40543edbadb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_278639d44a4a8372a627a7c31e9527c8faa26f97.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_27c2000d32c230a57a6712f27bc0fba02722f5fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_27c2000d32c230a57a6712f27bc0fba02722f5fd.hip new file mode 100644 index 00000000000..06f2ffea70c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_27c2000d32c230a57a6712f27bc0fba02722f5fd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_280bfced8745fbd9266207463fb41476dc23afff.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_280bfced8745fbd9266207463fb41476dc23afff.hip new file mode 100644 index 00000000000..2e857611381 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_280bfced8745fbd9266207463fb41476dc23afff.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 
= fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_281d897ad17d7f6db2741b396e6b85a9b8f35286.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_281d897ad17d7f6db2741b396e6b85a9b8f35286.hip new file mode 100644 index 00000000000..c4a5d12c6e6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_281d897ad17d7f6db2741b396e6b85a9b8f35286.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_285e61dad8f63fb973cb2eb899c959e400622652.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_285e61dad8f63fb973cb2eb899c959e400622652.hip new file mode 100644 index 00000000000..38d2df11bb8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_285e61dad8f63fb973cb2eb899c959e400622652.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_288458c5a0720ef152848713119ebce6d76db6d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_288458c5a0720ef152848713119ebce6d76db6d6.hip new file mode 100644 index 00000000000..c0655b73a18 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_288458c5a0720ef152848713119ebce6d76db6d6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_289071756e7d0582eb61ce6483fa3c988d2e10b5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_289071756e7d0582eb61ce6483fa3c988d2e10b5.hip new file mode 100644 index 00000000000..7413b4f6e86 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_289071756e7d0582eb61ce6483fa3c988d2e10b5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28e4d2c757e4b8c366a2c320360e21ff0ef671a8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28e4d2c757e4b8c366a2c320360e21ff0ef671a8.hip new file mode 100644 index 00000000000..baca96ca671 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28e4d2c757e4b8c366a2c320360e21ff0ef671a8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f1ef32c4384ec26f3dc5e3af6a74fc8cebae92.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f1ef32c4384ec26f3dc5e3af6a74fc8cebae92.hip new file mode 100644 index 00000000000..eae16006f2b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f1ef32c4384ec26f3dc5e3af6a74fc8cebae92.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f2e2b108a53308a0cb6c123c8d318cbc2eadb4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f2e2b108a53308a0cb6c123c8d318cbc2eadb4.hip new file mode 100644 index 00000000000..de4bee27c0d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f2e2b108a53308a0cb6c123c8d318cbc2eadb4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f7634d29bef11fd466b452a46b0612f38c949b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f7634d29bef11fd466b452a46b0612f38c949b.hip new file mode 100644 index 00000000000..541d95b9804 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f7634d29bef11fd466b452a46b0612f38c949b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_290c484c2a366258941ee0051e139ea716a9de2f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_290c484c2a366258941ee0051e139ea716a9de2f.hip new file mode 100644 index 00000000000..6d221ce6c9e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_290c484c2a366258941ee0051e139ea716a9de2f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_291a8bdf9d63b112e7fe5fa7e8835a6789cb8ecf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_291a8bdf9d63b112e7fe5fa7e8835a6789cb8ecf.hip new file mode 100644 index 00000000000..90e416b7df9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_291a8bdf9d63b112e7fe5fa7e8835a6789cb8ecf.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_292454f2d82184ab0491ea0675750c6ec55d659c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_292454f2d82184ab0491ea0675750c6ec55d659c.hip new file mode 100644 index 00000000000..82e69e6b972 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_292454f2d82184ab0491ea0675750c6ec55d659c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_292b4f995d622826af5d1f2bffa7ba68467c841a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_292b4f995d622826af5d1f2bffa7ba68467c841a.hip new file mode 100644 index 00000000000..644c096be64 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_292b4f995d622826af5d1f2bffa7ba68467c841a.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_295a523f815eb822d66162d4feb75fe0bc50b648.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_295a523f815eb822d66162d4feb75fe0bc50b648.hip new file mode 100644 index 00000000000..f1e151bb0c2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_295a523f815eb822d66162d4feb75fe0bc50b648.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_296c5836ba118969c4ba89ed62a98dffe3105738.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_296c5836ba118969c4ba89ed62a98dffe3105738.hip new file mode 100644 index 00000000000..6253aaf7b8d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_296c5836ba118969c4ba89ed62a98dffe3105738.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2995d39cd62f20622a31f11a292ed175abb5fdf9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2995d39cd62f20622a31f11a292ed175abb5fdf9.hip new file mode 100644 index 00000000000..ea580c781dd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2995d39cd62f20622a31f11a292ed175abb5fdf9.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29bffc159b0bb826ba489ae763dae141bfe8e802.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29bffc159b0bb826ba489ae763dae141bfe8e802.hip new file mode 100644 index 00000000000..7c2551e9486 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29bffc159b0bb826ba489ae763dae141bfe8e802.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29c9e5384809b21f39e78bb2e43af345a9a21d19.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29c9e5384809b21f39e78bb2e43af345a9a21d19.hip new file mode 100644 index 00000000000..f138f8cfe79 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29c9e5384809b21f39e78bb2e43af345a9a21d19.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29fe68ba10b3480dddc9866c51ca8b5efe962cc3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29fe68ba10b3480dddc9866c51ca8b5efe962cc3.hip new file mode 100644 index 00000000000..f80cf261633 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29fe68ba10b3480dddc9866c51ca8b5efe962cc3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a3a980a26682d879c3a3425f3ba5be3f5761adf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a3a980a26682d879c3a3425f3ba5be3f5761adf.hip new file mode 100644 index 00000000000..debe3536bf6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a3a980a26682d879c3a3425f3ba5be3f5761adf.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a45129fc4995abcb8f880692f11c6186fc01641.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a45129fc4995abcb8f880692f11c6186fc01641.hip new file mode 100644 index 00000000000..7ae88583ad5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a45129fc4995abcb8f880692f11c6186fc01641.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a833fc01e88bd8e256ef64ae8251dd0ed10720b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a833fc01e88bd8e256ef64ae8251dd0ed10720b.hip new file mode 100644 index 00000000000..52c67012d4f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a833fc01e88bd8e256ef64ae8251dd0ed10720b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a97c457144cb63a9c6c3d6be613b47bd0df9928.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a97c457144cb63a9c6c3d6be613b47bd0df9928.hip new file mode 100644 index 00000000000..0ae1cf1c5cf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a97c457144cb63a9c6c3d6be613b47bd0df9928.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = 
fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ad492377add5c8f6d0d2dbf9ee9e4338bbd9f1f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ad492377add5c8f6d0d2dbf9ee9e4338bbd9f1f.hip new file mode 100644 index 00000000000..1ef34b4126b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ad492377add5c8f6d0d2dbf9ee9e4338bbd9f1f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ae344010d49f7f9a6caab2cb84be7f87d2d96bf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ae344010d49f7f9a6caab2cb84be7f87d2d96bf.hip new file mode 100644 index 00000000000..62ac218caad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ae344010d49f7f9a6caab2cb84be7f87d2d96bf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2af6c5be53732eb1939a2f93232af7dc011dec1a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2af6c5be53732eb1939a2f93232af7dc011dec1a.hip new file mode 100644 index 00000000000..5417d1145fd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2af6c5be53732eb1939a2f93232af7dc011dec1a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b0bcb241e5a1be1d35366461408d06e095a26ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b0bcb241e5a1be1d35366461408d06e095a26ef.hip new file mode 100644 index 00000000000..9537b60aadd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b0bcb241e5a1be1d35366461408d06e095a26ef.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b3326e055da32cc979892a2fbd0f7b003cb9f98.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b3326e055da32cc979892a2fbd0f7b003cb9f98.hip new file mode 100644 index 00000000000..991da6f6a43 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b3326e055da32cc979892a2fbd0f7b003cb9f98.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b3af90387f1d227119c5dcd4b71362940bbce52.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b3af90387f1d227119c5dcd4b71362940bbce52.hip new file mode 100644 index 00000000000..32988018b50 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b3af90387f1d227119c5dcd4b71362940bbce52.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b4050988e5790a28dbe10b4c20e14f10f6cf85c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b4050988e5790a28dbe10b4c20e14f10f6cf85c.hip new file mode 100644 index 00000000000..ad3e20331fd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b4050988e5790a28dbe10b4c20e14f10f6cf85c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b49a9b0801a06dd89c7f7182d7590b515df1592.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b49a9b0801a06dd89c7f7182d7590b515df1592.hip new file mode 100644 index 00000000000..244458feef1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b49a9b0801a06dd89c7f7182d7590b515df1592.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b50073f6dfeb7ea77d5dce288a1d2f08f8f6362.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b50073f6dfeb7ea77d5dce288a1d2f08f8f6362.hip new file mode 100644 index 00000000000..8c4e06043bf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b50073f6dfeb7ea77d5dce288a1d2f08f8f6362.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b5317b6cde327a842170ebff20c2b03d81379ff.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b5317b6cde327a842170ebff20c2b03d81379ff.hip new file mode 100644 index 00000000000..36cc57f19a7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b5317b6cde327a842170ebff20c2b03d81379ff.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b8169ce4b4b9a17ac96fbb232e6a93f22071ab4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b8169ce4b4b9a17ac96fbb232e6a93f22071ab4.hip new file mode 100644 index 00000000000..a95c1e56eda --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b8169ce4b4b9a17ac96fbb232e6a93f22071ab4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b823c3b99e7c8d1cdc39a5dbc7365a383bf9ccb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b823c3b99e7c8d1cdc39a5dbc7365a383bf9ccb.hip new file mode 100644 index 00000000000..2d54c326f1a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b823c3b99e7c8d1cdc39a5dbc7365a383bf9ccb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ba934408c75da5479cc41f96b98ea7d333635ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ba934408c75da5479cc41f96b98ea7d333635ea.hip new file mode 100644 index 00000000000..e2166d14ee9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ba934408c75da5479cc41f96b98ea7d333635ea.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2bb6da1095bd8669c0e48b5cd808cf0dcefa2674.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2bb6da1095bd8669c0e48b5cd808cf0dcefa2674.hip new file mode 100644 index 00000000000..4f11e3066c6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2bb6da1095bd8669c0e48b5cd808cf0dcefa2674.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c0bda0feaade2b554d648d72f219ac9c389bf09.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c0bda0feaade2b554d648d72f219ac9c389bf09.hip new file mode 100644 index 00000000000..66f2134b17b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c0bda0feaade2b554d648d72f219ac9c389bf09.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c2e75e6f659a500dd3cf2cfd65118f111342119.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c2e75e6f659a500dd3cf2cfd65118f111342119.hip new file mode 100644 index 00000000000..81ecee533c2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c2e75e6f659a500dd3cf2cfd65118f111342119.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c77bd7e89ed832cc31b2995566a49bec6e4cb52.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c77bd7e89ed832cc31b2995566a49bec6e4cb52.hip new file mode 100644 index 00000000000..1465eede4a1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c77bd7e89ed832cc31b2995566a49bec6e4cb52.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c7aede7762a524a7a424cc4dc46e43fdedf73a2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c7aede7762a524a7a424cc4dc46e43fdedf73a2.hip new file mode 100644 index 00000000000..7c747ffa852 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c7aede7762a524a7a424cc4dc46e43fdedf73a2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c808da5c2514806c2953bb77d5692e5d7c97aa3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c808da5c2514806c2953bb77d5692e5d7c97aa3.hip new file mode 100644 index 00000000000..7bfcf3546a2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c808da5c2514806c2953bb77d5692e5d7c97aa3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c82e3c4e445e1e02f14435e4ca01a90850139a4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c82e3c4e445e1e02f14435e4ca01a90850139a4.hip new file mode 100644 index 00000000000..7041dde51e5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c82e3c4e445e1e02f14435e4ca01a90850139a4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c9756060ac0e73dbcfc58a9222a78f0283cd029.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c9756060ac0e73dbcfc58a9222a78f0283cd029.hip new file mode 100644 index 00000000000..653c7f1a947 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c9756060ac0e73dbcfc58a9222a78f0283cd029.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2caba3ab83239e474412fcf89fe0fbef97e51bf1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2caba3ab83239e474412fcf89fe0fbef97e51bf1.hip new file mode 100644 index 00000000000..997e311f175 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2caba3ab83239e474412fcf89fe0fbef97e51bf1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2cf351fc2c2da4a8e1760a3affc9a5947c6b3bda.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2cf351fc2c2da4a8e1760a3affc9a5947c6b3bda.hip new file mode 100644 index 00000000000..f76fafe72e7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2cf351fc2c2da4a8e1760a3affc9a5947c6b3bda.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d06f77a4054ca615d96636c0e2eba2a89850142.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d06f77a4054ca615d96636c0e2eba2a89850142.hip new file mode 100644 index 00000000000..8ac1c3f37dc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d06f77a4054ca615d96636c0e2eba2a89850142.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d1f2d1e57095f756ddd11e8e9d4f6f253e3ffa3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d1f2d1e57095f756ddd11e8e9d4f6f253e3ffa3.hip new file mode 100644 index 00000000000..d0f86685035 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d1f2d1e57095f756ddd11e8e9d4f6f253e3ffa3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d23a26e0a59a8323dd97632e610d24624143fbe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d23a26e0a59a8323dd97632e610d24624143fbe.hip new file mode 100644 index 00000000000..da3fbc3f584 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d23a26e0a59a8323dd97632e610d24624143fbe.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d43460c011b8d5e01ea98c9b8ddce962de59a96.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d43460c011b8d5e01ea98c9b8ddce962de59a96.hip new file mode 100644 index 00000000000..f4ce61655c8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d43460c011b8d5e01ea98c9b8ddce962de59a96.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d446754d7000673779d15d3e73039fd3c10a720.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d446754d7000673779d15d3e73039fd3c10a720.hip new file mode 100644 index 00000000000..90d14935193 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d446754d7000673779d15d3e73039fd3c10a720.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d7b637e0313cb423b22cd8844cc2997b3ff73e4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d7b637e0313cb423b22cd8844cc2997b3ff73e4.hip new file mode 100644 index 00000000000..7e4018d6cad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d7b637e0313cb423b22cd8844cc2997b3ff73e4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d9a04b7f41dd6f0db017157a44790f35c626e2d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d9a04b7f41dd6f0db017157a44790f35c626e2d.hip new file mode 100644 index 00000000000..1980cd16e4f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d9a04b7f41dd6f0db017157a44790f35c626e2d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d9c659ba43bb907fd4e3e36a50958288bafd1a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d9c659ba43bb907fd4e3e36a50958288bafd1a3.hip new file mode 100644 index 00000000000..dc9e3df1cb4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d9c659ba43bb907fd4e3e36a50958288bafd1a3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2da2b905c4ce32234c2af62328adae6b1f9217a8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2da2b905c4ce32234c2af62328adae6b1f9217a8.hip new file mode 100644 index 00000000000..ac60af950a3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2da2b905c4ce32234c2af62328adae6b1f9217a8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2db33b5442d2e0948762b1f2147a321a9d6907be.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2db33b5442d2e0948762b1f2147a321a9d6907be.hip new file mode 100644 index 00000000000..4521da9fba5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2db33b5442d2e0948762b1f2147a321a9d6907be.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2dfac5a83def98340c8786d55a30a98ad68b9eed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2dfac5a83def98340c8786d55a30a98ad68b9eed.hip new file mode 100644 index 00000000000..a3448cf7d45 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2dfac5a83def98340c8786d55a30a98ad68b9eed.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e30f50071113dc4ab59468d568ac9deb06b0342.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e30f50071113dc4ab59468d568ac9deb06b0342.hip new file mode 100644 index 00000000000..ffabb7038de --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e30f50071113dc4ab59468d568ac9deb06b0342.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e43e401abbfb1b6737e4dc822f68421abbc648a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e43e401abbfb1b6737e4dc822f68421abbc648a.hip new file mode 100644 index 00000000000..8b920baad7d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e43e401abbfb1b6737e4dc822f68421abbc648a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e8b4260626beeac76c26dbcee3cba1457b30e99.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e8b4260626beeac76c26dbcee3cba1457b30e99.hip new file mode 100644 index 00000000000..989fcedea2a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e8b4260626beeac76c26dbcee3cba1457b30e99.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ea394a09c8691a534ad2219bedf73724b6dd5ce.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ea394a09c8691a534ad2219bedf73724b6dd5ce.hip new file mode 100644 index 00000000000..1a3a6e656bc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ea394a09c8691a534ad2219bedf73724b6dd5ce.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2eba937ff6d0302ab013db7349d4feb914107f1f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2eba937ff6d0302ab013db7349d4feb914107f1f.hip new file mode 100644 index 00000000000..902e50fffcb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2eba937ff6d0302ab013db7349d4feb914107f1f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f0247e301a7b076b6ec8a778c3b47e330638963.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f0247e301a7b076b6ec8a778c3b47e330638963.hip new file mode 100644 index 00000000000..eaa36741dfc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f0247e301a7b076b6ec8a778c3b47e330638963.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f32f2d658f1f69840fbad511ce8a3851c859d52.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f32f2d658f1f69840fbad511ce8a3851c859d52.hip new file mode 100644 index 00000000000..9b8ea044844 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f32f2d658f1f69840fbad511ce8a3851c859d52.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f55a23a0f24ff7062a4c286944f25d2db3e20a4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f55a23a0f24ff7062a4c286944f25d2db3e20a4.hip new file mode 100644 index 00000000000..b5a1f73e406 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f55a23a0f24ff7062a4c286944f25d2db3e20a4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30024440e780fdf9ec94deccc85216d8bbb5788a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30024440e780fdf9ec94deccc85216d8bbb5788a.hip new file mode 100644 index 00000000000..4d706362bb1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30024440e780fdf9ec94deccc85216d8bbb5788a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_303b7b04496e4db7c1ba2436485dc7c8a4c88448.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_303b7b04496e4db7c1ba2436485dc7c8a4c88448.hip new file mode 100644 index 00000000000..4e73287a50f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_303b7b04496e4db7c1ba2436485dc7c8a4c88448.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3076a6de0e2612279e0ed64612f7393856bcc9ac.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3076a6de0e2612279e0ed64612f7393856bcc9ac.hip new file mode 100644 index 00000000000..dc18022abcf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3076a6de0e2612279e0ed64612f7393856bcc9ac.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30c8e4d5c761fda50e010da779e8e4730051d403.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30c8e4d5c761fda50e010da779e8e4730051d403.hip new file mode 100644 index 00000000000..accc3bf3513 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30c8e4d5c761fda50e010da779e8e4730051d403.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30f0200092b0e18d57a9f5e512d565f1c0229436.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30f0200092b0e18d57a9f5e512d565f1c0229436.hip new file mode 100644 index 00000000000..41dad6deb6b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30f0200092b0e18d57a9f5e512d565f1c0229436.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3108502fd29d3a24b32177bcea968121ee809115.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3108502fd29d3a24b32177bcea968121ee809115.hip new file mode 100644 index 00000000000..19a0b745976 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3108502fd29d3a24b32177bcea968121ee809115.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3110540b50e95e99a5cccebe47d9d3a83093c2fb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3110540b50e95e99a5cccebe47d9d3a83093c2fb.hip new file mode 100644 index 00000000000..3a7c3a96d14 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3110540b50e95e99a5cccebe47d9d3a83093c2fb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_311104394c8bef8d4ecff35c1409221e723a5a8a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_311104394c8bef8d4ecff35c1409221e723a5a8a.hip new file mode 100644 index 00000000000..595552c2eb4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_311104394c8bef8d4ecff35c1409221e723a5a8a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_311731442b756308c0a869f21b7b8b103aa613e8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_311731442b756308c0a869f21b7b8b103aa613e8.hip new file mode 100644 index 00000000000..7d8501d7ac3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_311731442b756308c0a869f21b7b8b103aa613e8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31222e158484773d2257f4a31e3dfbdb68336a8e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31222e158484773d2257f4a31e3dfbdb68336a8e.hip new file mode 100644 index 00000000000..1e69eeab252 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31222e158484773d2257f4a31e3dfbdb68336a8e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3163272d25bc2db2ffaa1fea87648b45ee68d408.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3163272d25bc2db2ffaa1fea87648b45ee68d408.hip new file mode 100644 index 00000000000..3a4bf5f9985 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3163272d25bc2db2ffaa1fea87648b45ee68d408.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_319df310195191895005b30151da8c1afab6c82f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_319df310195191895005b30151da8c1afab6c82f.hip new file mode 100644 index 00000000000..f1e5c3091e1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_319df310195191895005b30151da8c1afab6c82f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31a968898f0bc6366313e41eddb5e3a3ed12dc98.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31a968898f0bc6366313e41eddb5e3a3ed12dc98.hip new file mode 100644 index 00000000000..3dc985e15d1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31a968898f0bc6366313e41eddb5e3a3ed12dc98.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31b807c48c472e9b1311a6037cd98e21d6706889.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31b807c48c472e9b1311a6037cd98e21d6706889.hip new file mode 100644 index 00000000000..5656854a8dc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31b807c48c472e9b1311a6037cd98e21d6706889.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31c3760f5978baf9780ce4587ae4c768af0e49d1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31c3760f5978baf9780ce4587ae4c768af0e49d1.hip new file mode 100644 index 00000000000..b52b67fa2fd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31c3760f5978baf9780ce4587ae4c768af0e49d1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31c4b866692ba5c3d115482bef4790733863c1fc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31c4b866692ba5c3d115482bef4790733863c1fc.hip new file mode 100644 index 00000000000..cc6c0f0fbad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31c4b866692ba5c3d115482bef4790733863c1fc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3206cc121ce8955ed59ea3b12b858ee2e0cf82f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3206cc121ce8955ed59ea3b12b858ee2e0cf82f8.hip new file mode 100644 index 00000000000..32c52d79a31 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3206cc121ce8955ed59ea3b12b858ee2e0cf82f8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_320a6196b662a1d3dc7441a9536d825dc356b95d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_320a6196b662a1d3dc7441a9536d825dc356b95d.hip new file mode 100644 index 00000000000..c8f8761d6e2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_320a6196b662a1d3dc7441a9536d825dc356b95d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_321500dd4c41e4d68834814a48a639f5ca36a2fb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_321500dd4c41e4d68834814a48a639f5ca36a2fb.hip new file mode 100644 index 00000000000..2e522d847bf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_321500dd4c41e4d68834814a48a639f5ca36a2fb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_322a86568f89a5a5a165cfffbae9ca6949f2477e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_322a86568f89a5a5a165cfffbae9ca6949f2477e.hip new file mode 100644 index 00000000000..878cbe7e677 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_322a86568f89a5a5a165cfffbae9ca6949f2477e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32438250078ba2a47345ec4955dafb4e4de78a25.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32438250078ba2a47345ec4955dafb4e4de78a25.hip new file mode 100644 index 00000000000..92b96688e2b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32438250078ba2a47345ec4955dafb4e4de78a25.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32527660fa7aeb9a951a9f2fc3c53989bd141c48.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32527660fa7aeb9a951a9f2fc3c53989bd141c48.hip new file mode 100644 index 00000000000..7d3ca38f352 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32527660fa7aeb9a951a9f2fc3c53989bd141c48.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_325fbcb9e503e68fafea08abf86a4951f440850f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_325fbcb9e503e68fafea08abf86a4951f440850f.hip new file mode 100644 index 00000000000..8815c227c29 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_325fbcb9e503e68fafea08abf86a4951f440850f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32652a27e8605cef59c8341813b68e7513be23c5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32652a27e8605cef59c8341813b68e7513be23c5.hip new file mode 100644 index 00000000000..2e1046225d2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32652a27e8605cef59c8341813b68e7513be23c5.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_327e27892bc57f3dec0da24f94f2a483d6c9321b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_327e27892bc57f3dec0da24f94f2a483d6c9321b.hip new file mode 100644 index 00000000000..54bbc59eded --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_327e27892bc57f3dec0da24f94f2a483d6c9321b.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_328a311bafd1c153525393b252e4170f8aafb370.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_328a311bafd1c153525393b252e4170f8aafb370.hip new file mode 100644 index 00000000000..140bea9b300 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_328a311bafd1c153525393b252e4170f8aafb370.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33099fcfc218ffdf69edb4f2f0e46121bea9fafc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33099fcfc218ffdf69edb4f2f0e46121bea9fafc.hip new file mode 100644 index 00000000000..73ca1ca886d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33099fcfc218ffdf69edb4f2f0e46121bea9fafc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33746071156e9ad46f403a539dc237e0a44122a7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33746071156e9ad46f403a539dc237e0a44122a7.hip new file mode 100644 index 00000000000..cf909ff4b49 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33746071156e9ad46f403a539dc237e0a44122a7.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33e7c1e5f41a451c7baff54f7238b220f1bdf8a1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33e7c1e5f41a451c7baff54f7238b220f1bdf8a1.hip new file mode 
100644 index 00000000000..0c33783352c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33e7c1e5f41a451c7baff54f7238b220f1bdf8a1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = 
fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3400f0af03743dce328486f8fc805dd30bd6da31.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3400f0af03743dce328486f8fc805dd30bd6da31.hip new file mode 100644 index 00000000000..5bdb954c746 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3400f0af03743dce328486f8fc805dd30bd6da31.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + 
false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3408103188e27b3bc55dce0c1716c0b4d32d6494.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3408103188e27b3bc55dce0c1716c0b4d32d6494.hip new file mode 100644 index 00000000000..b434aed4fca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3408103188e27b3bc55dce0c1716c0b4d32d6494.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_342d29c85070f488a14b1915f948e5fd69019c99.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_342d29c85070f488a14b1915f948e5fd69019c99.hip new file mode 100644 index 00000000000..d68768627db --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_342d29c85070f488a14b1915f948e5fd69019c99.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_344932e2655d7b32704be8de9a63bbd8c3369f02.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_344932e2655d7b32704be8de9a63bbd8c3369f02.hip new file mode 100644 index 00000000000..51beb87a2b3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_344932e2655d7b32704be8de9a63bbd8c3369f02.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_345a939a2491166dc520e9a2b9de7e43671e0c2b.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_345a939a2491166dc520e9a2b9de7e43671e0c2b.hip new file mode 100644 index 00000000000..adf03ddef03 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_345a939a2491166dc520e9a2b9de7e43671e0c2b.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_345ea796c8d97bfe3b7c9663bf15e2e5e7696235.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_345ea796c8d97bfe3b7c9663bf15e2e5e7696235.hip new file mode 100644 index 00000000000..7f7027e9a8e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_345ea796c8d97bfe3b7c9663bf15e2e5e7696235.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_34807a8e90bf1cd839f32fd718afa6469c35a4fa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_34807a8e90bf1cd839f32fd718afa6469c35a4fa.hip new file mode 100644 index 00000000000..ddbca1e546c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_34807a8e90bf1cd839f32fd718afa6469c35a4fa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_349241529745bf138552f49d9a93db418663ad65.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_349241529745bf138552f49d9a93db418663ad65.hip new file mode 100644 index 00000000000..58b13dfa55d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_349241529745bf138552f49d9a93db418663ad65.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_34c2db98d8e2e690f499f41cfd5afb831b756f54.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_34c2db98d8e2e690f499f41cfd5afb831b756f54.hip new file mode 100644 index 00000000000..3e4446f52ff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_34c2db98d8e2e690f499f41cfd5afb831b756f54.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3511c54e6a6f9eec378d8b661121066536195d3a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3511c54e6a6f9eec378d8b661121066536195d3a.hip new file mode 100644 index 00000000000..83b2644097c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3511c54e6a6f9eec378d8b661121066536195d3a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_351425a006aeeff4d69c8570cb6bf1e1427d2c21.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_351425a006aeeff4d69c8570cb6bf1e1427d2c21.hip new file mode 100644 index 00000000000..98574609d77 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_351425a006aeeff4d69c8570cb6bf1e1427d2c21.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_354121d3bad1d448bd413718fa096f54faa12e95.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_354121d3bad1d448bd413718fa096f54faa12e95.hip new file mode 100644 index 00000000000..5e1fe0ac82c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_354121d3bad1d448bd413718fa096f54faa12e95.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_356f83cb96d0313abcdb24955edd4264df72aed7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_356f83cb96d0313abcdb24955edd4264df72aed7.hip new file mode 100644 index 00000000000..9f420d7fdfc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_356f83cb96d0313abcdb24955edd4264df72aed7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_357f7e626135cc9176a295f3d1f336a7c3852688.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_357f7e626135cc9176a295f3d1f336a7c3852688.hip new file mode 100644 index 00000000000..500ae2c9a2e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_357f7e626135cc9176a295f3d1f336a7c3852688.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_358399e756ed5026baf3ab78af17489dc07b9532.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_358399e756ed5026baf3ab78af17489dc07b9532.hip new file mode 100644 index 00000000000..fd0a60eb88e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_358399e756ed5026baf3ab78af17489dc07b9532.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_358d28c958c0a831a615a4811d13279b18db09c4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_358d28c958c0a831a615a4811d13279b18db09c4.hip new file mode 100644 index 00000000000..ecdcdfe1495 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_358d28c958c0a831a615a4811d13279b18db09c4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3642b78913a853a62dbff8b99d9ae3fa458f461d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3642b78913a853a62dbff8b99d9ae3fa458f461d.hip new file mode 100644 index 00000000000..0a74851481f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3642b78913a853a62dbff8b99d9ae3fa458f461d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_366662dccf2f650bcd8123c49006c759cd4c0ef6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_366662dccf2f650bcd8123c49006c759cd4c0ef6.hip new file mode 100644 index 00000000000..70e55fe4b94 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_366662dccf2f650bcd8123c49006c759cd4c0ef6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_367e58867c46d96c9bbaa96eaaa9f93595c9e099.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_367e58867c46d96c9bbaa96eaaa9f93595c9e099.hip new file mode 100644 index 00000000000..9a504bf3b41 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_367e58867c46d96c9bbaa96eaaa9f93595c9e099.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_36a0a960541bd8a2dc6741579de685b7c0a5f6d7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_36a0a960541bd8a2dc6741579de685b7c0a5f6d7.hip new file mode 100644 index 00000000000..52bec517464 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_36a0a960541bd8a2dc6741579de685b7c0a5f6d7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
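Note on the pattern repeated across the generated files above: each translation unit defines the tile/warp shapes, a pipeline problem, a kernel type, and a compile-time trait tag, and then explicitly specializes the launch function, the one-shot launch, and the name query on that tag, with the device code guarded for gfx90a/gfx942. The following standalone C++ sketch illustrates only that trait-tag dispatch idiom; every name in it (example_trait, run_kernel_, kernel_name_) is hypothetical and is not part of ck_tile or the PyTorch sources.

// Illustrative sketch only: a trimmed-down analogue of the trait-tag dispatch
// used by the generated files above. Hypothetical names throughout.
#include <iostream>
#include <string>

// A trait tag encodes, at compile time, which kernel variant a translation
// unit instantiates (dtype, head dim, mask/bias/dropout flags, ...).
template <int HDim, bool kIsCausal>
struct example_trait {};

// Primary templates are declared once (in a shared header in the real code);
// each generated file then provides exactly one explicit specialization.
template <typename Trait> float run_kernel_(float* data, int n);
template <typename Trait> std::string kernel_name_();

// One "generated" translation unit supplies specializations like these,
// typically wrapped in an arch guard so device code is emitted only for
// supported GPUs (the real files use __gfx90a__/__gfx942__ checks).
template <>
float run_kernel_<example_trait<64, true>>(float* data, int n) {
    float acc = 0.0f;
    for (int i = 0; i < n; ++i) acc += data[i];  // stand-in for a kernel launch
    return acc;
}

template <>
std::string kernel_name_<example_trait<64, true>>() {
    return "example_kernel_hdim64_causal";
}

int main() {
    float buf[4] = {1.f, 2.f, 3.f, 4.f};
    // A host-side dispatcher selects a Trait from runtime arguments and calls
    // the matching specialization, one per generated translation unit.
    std::cout << kernel_name_<example_trait<64, true>>() << ": "
              << run_kernel_<example_trait<64, true>>(buf, 4) << "\n";
    return 0;
}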
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_377b70f54cb2778b5ce3df936b477f775eea8b3c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_377b70f54cb2778b5ce3df936b477f775eea8b3c.hip new file mode 100644 index 00000000000..1aec4fa96c6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_377b70f54cb2778b5ce3df936b477f775eea8b3c.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_378759ae25465c32960487375828e23c5f1ac869.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_378759ae25465c32960487375828e23c5f1ac869.hip new file mode 100644 index 00000000000..7b20483d49c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_378759ae25465c32960487375828e23c5f1ac869.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_378bf438642e5d863e31145ada2a0688059aa5d9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_378bf438642e5d863e31145ada2a0688059aa5d9.hip new file mode 100644 index 00000000000..907a5f0d39d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_378bf438642e5d863e31145ada2a0688059aa5d9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_37ad61bf8427a26775969f8a9166fd0bfb7446b4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_37ad61bf8427a26775969f8a9166fd0bfb7446b4.hip new file mode 100644 index 00000000000..44d320675fc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_37ad61bf8427a26775969f8a9166fd0bfb7446b4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_37fe04467e87ec2110f60c7aea0cc9bf2ca07481.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_37fe04467e87ec2110f60c7aea0cc9bf2ca07481.hip new file mode 100644 index 00000000000..53602634b9c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_37fe04467e87ec2110f60c7aea0cc9bf2ca07481.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38010c9bf7341588f071f889b7a0b4dcc4e7a14c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38010c9bf7341588f071f889b7a0b4dcc4e7a14c.hip new file mode 100644 index 00000000000..ed72982e3e9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38010c9bf7341588f071f889b7a0b4dcc4e7a14c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_381b29d9888365bff0f109d897b508eebfd8a61f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_381b29d9888365bff0f109d897b508eebfd8a61f.hip new file mode 100644 index 00000000000..d4fb0830305 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_381b29d9888365bff0f109d897b508eebfd8a61f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3824e97d5ecba46e06d5ec1a9456c810d80227a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3824e97d5ecba46e06d5ec1a9456c810d80227a3.hip new file mode 100644 index 00000000000..58548b3c949 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3824e97d5ecba46e06d5ec1a9456c810d80227a3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38273a2f8e6bbb42ba0b0871b6c95abb34531f33.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38273a2f8e6bbb42ba0b0871b6c95abb34531f33.hip new file mode 100644 index 00000000000..5836a4b4a2d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38273a2f8e6bbb42ba0b0871b6c95abb34531f33.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38a5ff72f22e0ad040a281e66b1aca0bf3a2aadb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38a5ff72f22e0ad040a281e66b1aca0bf3a2aadb.hip new file mode 100644 index 00000000000..46e36eba6ef --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38a5ff72f22e0ad040a281e66b1aca0bf3a2aadb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38abcbeaa4d33d3150f2b0238bb62ebbfe960980.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38abcbeaa4d33d3150f2b0238bb62ebbfe960980.hip new file mode 100644 index 00000000000..1a297aa8e40 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38abcbeaa4d33d3150f2b0238bb62ebbfe960980.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38b94d76503e13c911781169fbc378517332c42e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38b94d76503e13c911781169fbc378517332c42e.hip new file mode 100644 index 00000000000..b8e7a9eb57d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38b94d76503e13c911781169fbc378517332c42e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38bb367362fe2c4849ded728ec5dd00969ce188f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38bb367362fe2c4849ded728ec5dd00969ce188f.hip new file mode 100644 index 00000000000..0955f50c488 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38bb367362fe2c4849ded728ec5dd00969ce188f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38e12dad9e3bafe177ed3c27c833825813e18fc3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38e12dad9e3bafe177ed3c27c833825813e18fc3.hip new file mode 100644 index 00000000000..951ef1eb5db --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38e12dad9e3bafe177ed3c27c833825813e18fc3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38f8a89468cf9c8606cf12a930db062a83cd0ea0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38f8a89468cf9c8606cf12a930db062a83cd0ea0.hip new file mode 100644 index 00000000000..872a38cd720 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38f8a89468cf9c8606cf12a930db062a83cd0ea0.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3937d9dfb68351de2942e32f35e2ca1ce71edfa8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3937d9dfb68351de2942e32f35e2ca1ce71edfa8.hip new file mode 
100644 index 00000000000..3ec1276181a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3937d9dfb68351de2942e32f35e2ca1ce71edfa8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto 
[kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_39422621a00ff79b2f5ec0dafb957c77693537b3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_39422621a00ff79b2f5ec0dafb957c77693537b3.hip new file mode 100644 index 00000000000..2a8d74f47e2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_39422621a00ff79b2f5ec0dafb957c77693537b3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3967a8807c9451b09227c0f685c18aafeb062fd2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3967a8807c9451b09227c0f685c18aafeb062fd2.hip new 
file mode 100644 index 00000000000..0ddbbcdaaa5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3967a8807c9451b09227c0f685c18aafeb062fd2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3992d5df4ba2e999caf6889a852db4e1ba078e65.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3992d5df4ba2e999caf6889a852db4e1ba078e65.hip new file mode 100644 index 00000000000..dbef7fd0994 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3992d5df4ba2e999caf6889a852db4e1ba078e65.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_39d3071347a0c98f3221104036f477aa13bffa4d.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_39d3071347a0c98f3221104036f477aa13bffa4d.hip new file mode 100644 index 00000000000..457fc9f0fc9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_39d3071347a0c98f3221104036f477aa13bffa4d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> 
+void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a1dca5feb864e8981387c2d07e62acef1730aa8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a1dca5feb864e8981387c2d07e62acef1730aa8.hip new file mode 100644 index 00000000000..d937f8f21b4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a1dca5feb864e8981387c2d07e62acef1730aa8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a2280997eb6f1d091094fc54cecf42b7c9c3a2d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a2280997eb6f1d091094fc54cecf42b7c9c3a2d.hip new file mode 100644 index 00000000000..ff32e3786f2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a2280997eb6f1d091094fc54cecf42b7c9c3a2d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a2643099365d0903c799585f41dc1a525ac9f9e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a2643099365d0903c799585f41dc1a525ac9f9e.hip new file mode 100644 index 00000000000..5805472bea7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a2643099365d0903c799585f41dc1a525ac9f9e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a6b9566559ed2b1c85f2bea1c55e72c41dc47bd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a6b9566559ed2b1c85f2bea1c55e72c41dc47bd.hip new file mode 100644 index 00000000000..87d6878c3e7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a6b9566559ed2b1c85f2bea1c55e72c41dc47bd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3af86f458fb4dfcceb7db3357fbae0dc15142a15.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3af86f458fb4dfcceb7db3357fbae0dc15142a15.hip new file mode 100644 index 00000000000..c8cead57404 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3af86f458fb4dfcceb7db3357fbae0dc15142a15.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3afbb5ac9048a962a60f48886728220ae6c2aeaf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3afbb5ac9048a962a60f48886728220ae6c2aeaf.hip new file mode 100644 index 00000000000..8dd1c686396 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3afbb5ac9048a962a60f48886728220ae6c2aeaf.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b26eafe76cca8e74e819220b6de1f4279d48e43.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b26eafe76cca8e74e819220b6de1f4279d48e43.hip new file mode 100644 index 00000000000..5a321a93c5f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b26eafe76cca8e74e819220b6de1f4279d48e43.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b4ecb47f9ebe8c2784976c3e9bbe4834b475cf1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b4ecb47f9ebe8c2784976c3e9bbe4834b475cf1.hip new file mode 100644 index 00000000000..3016cd469f3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b4ecb47f9ebe8c2784976c3e9bbe4834b475cf1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b508b92f7e123b21658f6e17d624ffa87831fee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b508b92f7e123b21658f6e17d624ffa87831fee.hip new file mode 100644 index 00000000000..0bb1a69aa48 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b508b92f7e123b21658f6e17d624ffa87831fee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b5b3c218e4a7b459e54080e24c5b730221eac02.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b5b3c218e4a7b459e54080e24c5b730221eac02.hip new file mode 100644 index 00000000000..721ff085113 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b5b3c218e4a7b459e54080e24c5b730221eac02.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bb129e6dee6848043dd0e8fa812ae80fec4d014.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bb129e6dee6848043dd0e8fa812ae80fec4d014.hip new file mode 100644 index 00000000000..4cc230a14a7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bb129e6dee6848043dd0e8fa812ae80fec4d014.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bb3b682eab96e4e173affad75b9d8e73f1dd690.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bb3b682eab96e4e173affad75b9d8e73f1dd690.hip new file mode 100644 index 00000000000..ae316743ab6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bb3b682eab96e4e173affad75b9d8e73f1dd690.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3be7cea6df8e6dd56194e1172f28943667f1c4ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3be7cea6df8e6dd56194e1172f28943667f1c4ef.hip new file mode 100644 index 00000000000..f9175f9604c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3be7cea6df8e6dd56194e1172f28943667f1c4ef.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bed3aaf24c73073c604a3b23bb4b0358b8e3490.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bed3aaf24c73073c604a3b23bb4b0358b8e3490.hip new file mode 100644 index 00000000000..128a085e485 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bed3aaf24c73073c604a3b23bb4b0358b8e3490.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c1454ffc1418dac641f63671e947d9f550b1f0c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c1454ffc1418dac641f63671e947d9f550b1f0c.hip new file mode 100644 index 00000000000..5654f69141f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c1454ffc1418dac641f63671e947d9f550b1f0c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( 
+ ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c38bb80e9880335faaea81985ed5d0e713ecb08.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c38bb80e9880335faaea81985ed5d0e713ecb08.hip new file mode 100644 index 00000000000..5798ef24307 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c38bb80e9880335faaea81985ed5d0e713ecb08.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c3b7e4b8c1efe59f79a15512716fce2282a79a7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c3b7e4b8c1efe59f79a15512716fce2282a79a7.hip new file mode 100644 index 00000000000..fd12889cf47 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c3b7e4b8c1efe59f79a15512716fce2282a79a7.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c64c33870ebc329921cfa3867d58b1857421f65.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c64c33870ebc329921cfa3867d58b1857421f65.hip new file mode 100644 index 00000000000..5115f25ea32 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c64c33870ebc329921cfa3867d58b1857421f65.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cb0cee09d633b6f70febbba63a1e090522cfb4a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cb0cee09d633b6f70febbba63a1e090522cfb4a.hip new file mode 100644 index 00000000000..dfeef4e22f2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cb0cee09d633b6f70febbba63a1e090522cfb4a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cce3baac1e3ca03af0c3f4ee4d0158ad1031e9f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cce3baac1e3ca03af0c3f4ee4d0158ad1031e9f.hip new file mode 100644 index 00000000000..1121d713631 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cce3baac1e3ca03af0c3f4ee4d0158ad1031e9f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3ccf0a9d5a5451da5dbf6075ccea45e4a140550a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3ccf0a9d5a5451da5dbf6075ccea45e4a140550a.hip new file mode 100644 index 00000000000..9e84ec8cd3a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3ccf0a9d5a5451da5dbf6075ccea45e4a140550a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cd7a9ca49c1149d46f6b05b0fefc41ecaeb6ea1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cd7a9ca49c1149d46f6b05b0fefc41ecaeb6ea1.hip new file mode 100644 index 00000000000..68f833e3b58 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cd7a9ca49c1149d46f6b05b0fefc41ecaeb6ea1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cf45927b6d931e31e2209685d787efa28eed8ba.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cf45927b6d931e31e2209685d787efa28eed8ba.hip new file mode 100644 index 00000000000..8d24261fb72 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cf45927b6d931e31e2209685d787efa28eed8ba.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d1cea88a2277b87d405025ba256272a1720f88d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d1cea88a2277b87d405025ba256272a1720f88d.hip new file mode 100644 index 00000000000..86d662d1374 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d1cea88a2277b87d405025ba256272a1720f88d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d289100991d4c8c362f64c8f6c4ba395c2f3495.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d289100991d4c8c362f64c8f6c4ba395c2f3495.hip new file mode 100644 index 00000000000..0e52b9cee61 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d289100991d4c8c362f64c8f6c4ba395c2f3495.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d3f3eb2f5eb1f3287879604892b1c230df85f1d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d3f3eb2f5eb1f3287879604892b1c230df85f1d.hip new file mode 100644 index 00000000000..e480475f08e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d3f3eb2f5eb1f3287879604892b1c230df85f1d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} 
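Every kernel instantiation file in this patch follows the same dispatch shape: it defines one trait alias (trait_0, dq_dk_dv_trait_0, or dot_do_o_trait_0), specializes the matching launcher (fmha_fwd_, fmha_bwd_dq_dk_dv_, fmha_bwd_dot_do_o_) for that trait, and only emits a real launch when the translation unit is compiled for gfx90a or gfx942, returning 0.0 otherwise. The sketch below restates that pattern for the forward case as a reading aid; it assumes trait_0 and fmha_kernel_0 are defined as in the generated files above, and the explicit make_kernel template arguments reflect the usual ck_tile launch helper rather than anything introduced by this patch.

// Hedged sketch of the per-file dispatch pattern; trait_0 / fmha_kernel_0 are
// assumed to be the aliases defined in one of the autogenerated files above.
template <>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
    using k_ = fmha_kernel_0;
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush; // optional per-launch trace
    auto [kargs, grids]                    = fmha_fwd_create_kargs_and_grids(a);
    constexpr dim3 blocks                  = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
    // Real launch only on the targeted gfx archs; other archs compile this
    // translation unit to the stub below so the dispatcher still links.
    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
    return 0.0; // elapsed-time placeholder on unsupported archs
#endif
}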
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d45624dc6e33c477c73a155500b015b6c010de8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d45624dc6e33c477c73a155500b015b6c010de8.hip new file mode 100644 index 00000000000..2e713305b18 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d45624dc6e33c477c73a155500b015b6c010de8.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d55cb42b0096a8ae338ce100f86e378aa1a04c9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d55cb42b0096a8ae338ce100f86e378aa1a04c9.hip new file mode 100644 index 00000000000..3d4188634ad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d55cb42b0096a8ae338ce100f86e378aa1a04c9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3da8c31f6d5bcaacfa4a21aed4d1d3caecb48922.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3da8c31f6d5bcaacfa4a21aed4d1d3caecb48922.hip new file mode 100644 index 00000000000..45354957b43 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3da8c31f6d5bcaacfa4a21aed4d1d3caecb48922.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3dba3cd44f78c950fe7ceaa5f0629dfc607b30f1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3dba3cd44f78c950fe7ceaa5f0629dfc607b30f1.hip new file mode 100644 index 00000000000..e98c243b41b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3dba3cd44f78c950fe7ceaa5f0629dfc607b30f1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3dff884e176ec7cff86d17c6afe1ddaa4dd6007d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3dff884e176ec7cff86d17c6afe1ddaa4dd6007d.hip new file mode 100644 index 00000000000..2d38d9ce9e6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3dff884e176ec7cff86d17c6afe1ddaa4dd6007d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e143d88eaa0d9cfea856b2f3a57d1275a656627.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e143d88eaa0d9cfea856b2f3a57d1275a656627.hip new file mode 100644 index 00000000000..e2bdb517ead --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e143d88eaa0d9cfea856b2f3a57d1275a656627.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e2557f206fd81d82a3b9d59113105040beb891f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e2557f206fd81d82a3b9d59113105040beb891f.hip new file mode 100644 index 00000000000..fd6a937fec6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e2557f206fd81d82a3b9d59113105040beb891f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e562e6c3af28b8478020ce3c3bf73c036001c93.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e562e6c3af28b8478020ce3c3bf73c036001c93.hip new file mode 100644 index 00000000000..cad897f0101 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e562e6c3af28b8478020ce3c3bf73c036001c93.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e61b019e1398a6a3c36143fb84b5ff22c9f4508.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e61b019e1398a6a3c36143fb84b5ff22c9f4508.hip new file mode 100644 index 00000000000..c110e4b7f6d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e61b019e1398a6a3c36143fb84b5ff22c9f4508.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e839660557dee9d5bcda9b56940ce23236c5f6d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e839660557dee9d5bcda9b56940ce23236c5f6d.hip new file mode 100644 index 00000000000..ae68340237a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e839660557dee9d5bcda9b56940ce23236c5f6d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3eb2ea922daabbba131b90713e06d8caf5f30662.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3eb2ea922daabbba131b90713e06d8caf5f30662.hip new file mode 100644 index 00000000000..c963f4cd625 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3eb2ea922daabbba131b90713e06d8caf5f30662.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3ecf565a5a1c4a09887c67ac3b9a019dca427ac0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3ecf565a5a1c4a09887c67ac3b9a019dca427ac0.hip new file mode 100644 index 00000000000..875e8ce0f19 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3ecf565a5a1c4a09887c67ac3b9a019dca427ac0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f34433b784d1e405ade3378918641372a30bf6b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f34433b784d1e405ade3378918641372a30bf6b.hip new file mode 100644 index 00000000000..fd0e82c01b2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f34433b784d1e405ade3378918641372a30bf6b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f5e01b4f2ca8ea10898c39d6570bd74e85f46ed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f5e01b4f2ca8ea10898c39d6570bd74e85f46ed.hip new file mode 100644 index 00000000000..1ed5d01a6d7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f5e01b4f2ca8ea10898c39d6570bd74e85f46ed.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f7315955f555768f24585a50d75e216c40f062d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f7315955f555768f24585a50d75e216c40f062d.hip new file mode 100644 index 00000000000..71bc0a9e7c0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f7315955f555768f24585a50d75e216c40f062d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3fad30ff0739ab5dede67a96e859f8c474c245f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3fad30ff0739ab5dede67a96e859f8c474c245f8.hip new file mode 100644 index 00000000000..979c0bfdda3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3fad30ff0739ab5dede67a96e859f8c474c245f8.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3fcc6893456a559c7d22714116022fc69b372266.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3fcc6893456a559c7d22714116022fc69b372266.hip new file 
mode 100644 index 00000000000..c8e79b127c3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3fcc6893456a559c7d22714116022fc69b372266.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4018b1fcee808b6cccd131418b6ae9e8bf900d8f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4018b1fcee808b6cccd131418b6ae9e8bf900d8f.hip new file mode 100644 index 00000000000..003b14e205e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4018b1fcee808b6cccd131418b6ae9e8bf900d8f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4018f690b6322588041bb467beabd8a7bc79a2e0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4018f690b6322588041bb467beabd8a7bc79a2e0.hip new file mode 100644 index 00000000000..e703b937148 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4018f690b6322588041bb467beabd8a7bc79a2e0.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40357c5e9739eae136a7abf92bc38d3ac94753f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40357c5e9739eae136a7abf92bc38d3ac94753f8.hip new file mode 100644 index 00000000000..0c184f43ee7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40357c5e9739eae136a7abf92bc38d3ac94753f8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4052ca6a3ec02f6559e4bbf1edde42ad2d127c26.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4052ca6a3ec02f6559e4bbf1edde42ad2d127c26.hip new file mode 100644 index 00000000000..f24e798a332 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4052ca6a3ec02f6559e4bbf1edde42ad2d127c26.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_405e7efa263223148318ae96bd1929b382e994e1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_405e7efa263223148318ae96bd1929b382e994e1.hip new file mode 100644 index 00000000000..019ebafe587 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_405e7efa263223148318ae96bd1929b382e994e1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40aa64439b80ff8dd12498b3e5f6b625da16e285.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40aa64439b80ff8dd12498b3e5f6b625da16e285.hip new file mode 100644 index 00000000000..2f53b38c163 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40aa64439b80ff8dd12498b3e5f6b625da16e285.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40db688a9189e1c47c300d474df946a248a63303.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40db688a9189e1c47c300d474df946a248a63303.hip new file mode 100644 index 00000000000..bd1baeb57f5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40db688a9189e1c47c300d474df946a248a63303.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4118e3ab290263ed2576feaf22a1944bf2ddcb7a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4118e3ab290263ed2576feaf22a1944bf2ddcb7a.hip new file mode 100644 index 00000000000..18b57f6f108 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4118e3ab290263ed2576feaf22a1944bf2ddcb7a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_415b183c50dd2663dabe3eb8b780913b778c54ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_415b183c50dd2663dabe3eb8b780913b778c54ab.hip new file mode 100644 index 00000000000..b662b55b69b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_415b183c50dd2663dabe3eb8b780913b778c54ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4160f6b6d0869740a5a411abd80108f729f810eb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4160f6b6d0869740a5a411abd80108f729f810eb.hip new file mode 100644 index 00000000000..1eb0f4fd3ff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4160f6b6d0869740a5a411abd80108f729f810eb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_417b1cb14b67dc82f614831550f7deb0895bd7e4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_417b1cb14b67dc82f614831550f7deb0895bd7e4.hip new file mode 100644 index 00000000000..83ef09fb103 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_417b1cb14b67dc82f614831550f7deb0895bd7e4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_419461cdb5687ebbb7bf0be136071d70420c1619.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_419461cdb5687ebbb7bf0be136071d70420c1619.hip new file mode 100644 index 00000000000..dd3ded85bb1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_419461cdb5687ebbb7bf0be136071d70420c1619.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_41b68458076e6cb129d3ec793e95b91430a0c8a1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_41b68458076e6cb129d3ec793e95b91430a0c8a1.hip new file mode 100644 index 00000000000..306bfbcf251 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_41b68458076e6cb129d3ec793e95b91430a0c8a1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_41db3f29d1940e59dadc357c040ea37a6ff208d9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_41db3f29d1940e59dadc357c040ea37a6ff208d9.hip new file mode 100644 index 00000000000..3fac933775f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_41db3f29d1940e59dadc357c040ea37a6ff208d9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4217a48a1677bd26cd48e512f1fc8830a8a551b8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4217a48a1677bd26cd48e512f1fc8830a8a551b8.hip new file mode 100644 index 00000000000..47c15f956e7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4217a48a1677bd26cd48e512f1fc8830a8a551b8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_428ce4e14cf94b284ffa735fe03d923cc74c9fe0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_428ce4e14cf94b284ffa735fe03d923cc74c9fe0.hip new file mode 100644 index 00000000000..a0d0d02b999 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_428ce4e14cf94b284ffa735fe03d923cc74c9fe0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_429b82a27571ac91e3631cbdb7e0a58155abf962.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_429b82a27571ac91e3631cbdb7e0a58155abf962.hip new file mode 100644 index 00000000000..b376dd808e6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_429b82a27571ac91e3631cbdb7e0a58155abf962.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_42e2326066c91452335eac05f25a6311376bd9e5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_42e2326066c91452335eac05f25a6311376bd9e5.hip new file mode 100644 index 00000000000..c479c02017d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_42e2326066c91452335eac05f25a6311376bd9e5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4306c6c37cf472ad262f53941611b5e60072bdf6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4306c6c37cf472ad262f53941611b5e60072bdf6.hip new file mode 100644 index 00000000000..98aac8b5ab7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4306c6c37cf472ad262f53941611b5e60072bdf6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4347e039c003489dd528faf5d710e687321a3fd7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4347e039c003489dd528faf5d710e687321a3fd7.hip new file mode 100644 index 00000000000..b2f58522f16 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4347e039c003489dd528faf5d710e687321a3fd7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4356b3a2ff49f72b91a6b9c215df285f2798ad47.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4356b3a2ff49f72b91a6b9c215df285f2798ad47.hip new file mode 100644 index 00000000000..d419d50e803 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4356b3a2ff49f72b91a6b9c215df285f2798ad47.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
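Reviewer note on the recurring shape of these generated files: each translation unit emitted by generate.py pins one combination of data type, tile shape, pipeline, mask, bias, dropout and padding flags into a trait tag (trait_0, dq_dk_dv_trait_0 or convert_dq_trait_0) and then supplies the explicit specialization of the matching launcher for exactly that trait (fmha_fwd_, fmha_bwd_dq_dk_dv_, fmha_bwd_convert_dq_, plus the *_oneshot_ and *_get_name_ helpers where applicable). The launch body is guarded on gfx90a/gfx942, so on any other architecture the file still compiles and links but the launcher degenerates to a no-op returning 0.0. The sketch below is a simplified, hypothetical analogue of that one-specialization-per-file pattern; every name in it is invented for illustration and is not a CK identifier.

// Hypothetical, simplified analogue of the autogenerated instantiation
// pattern. All names are illustrative stand-ins, not real CK identifiers.
#include <iostream>
#include <string>

struct launch_args {            // stand-in for fmha_bwd_args / stream_config
    int log_level = 0;
};

template <int HeadDim, bool Deterministic>
struct example_trait {};        // stand-in for dq_dk_dv_trait_0 and friends

// Primary templates: declared once, specialized once per generated file.
template <typename Trait> std::string name_for();
template <typename Trait> float       launch_for(const launch_args& a);

// --- what a single autogenerated file contributes --------------------------
template <>
std::string name_for<example_trait<128, true>>()
{
    return "example_kernel_hdim128_deterministic";
}

template <>
float launch_for<example_trait<128, true>>(const launch_args& a)
{
    if (a.log_level > 0)
        std::cout << ", " << name_for<example_trait<128, true>>() << std::flush;
#if defined(__gfx90a__) || defined(__gfx942__)
    // The real files build kargs/grids here and call ck_tile::launch_kernel.
    return 1.0f;                // placeholder for the measured kernel time
#else
    return 0.0f;                // no-op on unsupported architectures
#endif
}
// ----------------------------------------------------------------------------

int main()
{
    launch_args a{/*log_level=*/1};
    std::cout << name_for<example_trait<128, true>>() << " -> "
              << launch_for<example_trait<128, true>>(a) << '\n';
}

Under this scheme, adding or removing a tile configuration only adds or removes one translation unit; the primary launcher template and the code that dispatches on traits never change, which is why the generator can emit one file per configuration hash as seen throughout this directory.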
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4377ac04be3a6cbdbfbe57612a469412812fb5b5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4377ac04be3a6cbdbfbe57612a469412812fb5b5.hip new file mode 100644 index 00000000000..3e326ee37af --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4377ac04be3a6cbdbfbe57612a469412812fb5b5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_438e3565f4c720e6c9691b0d33c1392936e2e7ae.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_438e3565f4c720e6c9691b0d33c1392936e2e7ae.hip new file mode 100644 index 00000000000..a3c0807ec0b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_438e3565f4c720e6c9691b0d33c1392936e2e7ae.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4395d3c96b3f4556b9765fd0a3b5701b2fb10948.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4395d3c96b3f4556b9765fd0a3b5701b2fb10948.hip new file mode 100644 index 00000000000..1e96c426d21 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4395d3c96b3f4556b9765fd0a3b5701b2fb10948.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + true, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_43e7c78e8f65be35e2753a0ad5123118555c56b2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_43e7c78e8f65be35e2753a0ad5123118555c56b2.hip new file mode 100644 index 00000000000..2f1461f5712 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_43e7c78e8f65be35e2753a0ad5123118555c56b2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_43f2156a04b18bab55af60e9357f28d8a4604e8e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_43f2156a04b18bab55af60e9357f28d8a4604e8e.hip new file mode 100644 index 00000000000..fcc0d7ac7bc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_43f2156a04b18bab55af60e9357f28d8a4604e8e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4409f2a7deb027e864afdfc9975d3ab93c5dcc9a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4409f2a7deb027e864afdfc9975d3ab93c5dcc9a.hip new file mode 100644 index 00000000000..62c4fdb191b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4409f2a7deb027e864afdfc9975d3ab93c5dcc9a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4432c5214c4d40c54ca2d02f0d4785c6d6902370.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4432c5214c4d40c54ca2d02f0d4785c6d6902370.hip new file mode 100644 index 00000000000..7aea6d7a4c2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4432c5214c4d40c54ca2d02f0d4785c6d6902370.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44462715ed5f192532760d6f4c66ff9d4e20e254.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44462715ed5f192532760d6f4c66ff9d4e20e254.hip new file mode 100644 index 00000000000..6f48793bb26 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44462715ed5f192532760d6f4c66ff9d4e20e254.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44564dddf8b492d80be54854abb8d1d831e42679.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44564dddf8b492d80be54854abb8d1d831e42679.hip new file mode 100644 index 00000000000..ab9446438c8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44564dddf8b492d80be54854abb8d1d831e42679.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_445cd8fa559588f4264ce6192f2de3e3065365ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_445cd8fa559588f4264ce6192f2de3e3065365ea.hip new file mode 100644 index 00000000000..1968b601b68 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_445cd8fa559588f4264ce6192f2de3e3065365ea.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_445e28a8a51cd435130ded2abc9fc606e522c713.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_445e28a8a51cd435130ded2abc9fc606e522c713.hip new file mode 100644 index 00000000000..dafbd57a874 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_445e28a8a51cd435130ded2abc9fc606e522c713.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4462b192a64efb60d5484798526278ac7a0fb9fa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4462b192a64efb60d5484798526278ac7a0fb9fa.hip new file mode 100644 index 00000000000..70e353ec443 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4462b192a64efb60d5484798526278ac7a0fb9fa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4466b6c6b2ec3acb40ac1cda432efa1e4e62d9d9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4466b6c6b2ec3acb40ac1cda432efa1e4e62d9d9.hip new file mode 100644 index 00000000000..62f5cfe2252 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4466b6c6b2ec3acb40ac1cda432efa1e4e62d9d9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44690e48f30657b0fcfa26fb3b9af3ef76e792e3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44690e48f30657b0fcfa26fb3b9af3ef76e792e3.hip new file mode 100644 index 00000000000..b6b4283da67 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44690e48f30657b0fcfa26fb3b9af3ef76e792e3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44c181996532676f2140fd026707135144e9d37b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44c181996532676f2140fd026707135144e9d37b.hip new file mode 100644 index 00000000000..8f05e41b740 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44c181996532676f2140fd026707135144e9d37b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44cc95831c347212021c0bab7b43acd7daabce42.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44cc95831c347212021c0bab7b43acd7daabce42.hip new file mode 100644 index 00000000000..a14760b5823 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44cc95831c347212021c0bab7b43acd7daabce42.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44d82b58fdc3e5b7a7c20490ce7f5acce4e6ec79.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44d82b58fdc3e5b7a7c20490ce7f5acce4e6ec79.hip new file mode 100644 index 00000000000..321626f2d28 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44d82b58fdc3e5b7a7c20490ce7f5acce4e6ec79.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_451fbbdc2dcf2ec81efce34673ee6c425cc16ca2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_451fbbdc2dcf2ec81efce34673ee6c425cc16ca2.hip new file mode 100644 index 00000000000..cd9499e0354 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_451fbbdc2dcf2ec81efce34673ee6c425cc16ca2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4568af1b2f104664fd05d21ad789aed39ecfa42b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4568af1b2f104664fd05d21ad789aed39ecfa42b.hip new file mode 100644 index 00000000000..5ff6782c43c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4568af1b2f104664fd05d21ad789aed39ecfa42b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_457eaffbff3c58183a656687010daa2c16cfc26e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_457eaffbff3c58183a656687010daa2c16cfc26e.hip new file mode 100644 index 00000000000..3c2f785595f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_457eaffbff3c58183a656687010daa2c16cfc26e.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_458d708d13577f2b92e6d5adfe952a87e0cf7be5.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_458d708d13577f2b92e6d5adfe952a87e0cf7be5.hip new file mode 100644 index 00000000000..978f29329e4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_458d708d13577f2b92e6d5adfe952a87e0cf7be5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + 
+template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_459c8fb6028991321b09a990c2188d854d940268.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_459c8fb6028991321b09a990c2188d854d940268.hip new file mode 100644 index 00000000000..a7206f643dd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_459c8fb6028991321b09a990c2188d854d940268.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_459ea3713aef9b916e1b38a882a45012930924d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_459ea3713aef9b916e1b38a882a45012930924d3.hip new file mode 100644 index 00000000000..ade3464bbb2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_459ea3713aef9b916e1b38a882a45012930924d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_45b9871c220c0065d74bffeed4021d0304a9625c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_45b9871c220c0065d74bffeed4021d0304a9625c.hip new file mode 100644 index 00000000000..3cd9c78e3e9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_45b9871c220c0065d74bffeed4021d0304a9625c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_45f4363f50af1e7ccd24751d5f5b181bf32c604f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_45f4363f50af1e7ccd24751d5f5b181bf32c604f.hip new file mode 100644 index 00000000000..ff6b89fae8b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_45f4363f50af1e7ccd24751d5f5b181bf32c604f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4601680af41c8738089ff377147e0547dcad114d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4601680af41c8738089ff377147e0547dcad114d.hip new file mode 100644 index 00000000000..9f6fd356789 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4601680af41c8738089ff377147e0547dcad114d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_461737a13e24009bf1a5a4b780175043a9f2e33e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_461737a13e24009bf1a5a4b780175043a9f2e33e.hip new file mode 100644 index 00000000000..33bfc80603c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_461737a13e24009bf1a5a4b780175043a9f2e33e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4666db0ff7b035e54f2c0e59acedc2131b722a55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4666db0ff7b035e54f2c0e59acedc2131b722a55.hip new file mode 100644 index 00000000000..9d05594cf21 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4666db0ff7b035e54f2c0e59acedc2131b722a55.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_468a5f057fd5cef2df5f919f5102f47e86901e3b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_468a5f057fd5cef2df5f919f5102f47e86901e3b.hip new file mode 100644 index 00000000000..a78256f1548 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_468a5f057fd5cef2df5f919f5102f47e86901e3b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_474fe2d739eca8c93fdcb2c105d4154cee6ca1c1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_474fe2d739eca8c93fdcb2c105d4154cee6ca1c1.hip new file mode 100644 index 00000000000..220a5a4f2ce --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_474fe2d739eca8c93fdcb2c105d4154cee6ca1c1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47548aa042c69bb9c59a8bf706b44028aaa41830.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47548aa042c69bb9c59a8bf706b44028aaa41830.hip new file mode 100644 index 00000000000..309681f33ce --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47548aa042c69bb9c59a8bf706b44028aaa41830.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47f3ced9b5ddb0dfee8ed5e7df8eca0bbe273047.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47f3ced9b5ddb0dfee8ed5e7df8eca0bbe273047.hip new file mode 100644 index 00000000000..8b2a1e2a3b0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47f3ced9b5ddb0dfee8ed5e7df8eca0bbe273047.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47fe73f04cef91cd2a0682e905483968ff80eadb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47fe73f04cef91cd2a0682e905483968ff80eadb.hip new file mode 100644 index 00000000000..ae19c51db82 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47fe73f04cef91cd2a0682e905483968ff80eadb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_481415463f0316ebe25ff2fda47c68cc54db3359.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_481415463f0316ebe25ff2fda47c68cc54db3359.hip new file mode 100644 index 00000000000..c7c7457b800 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_481415463f0316ebe25ff2fda47c68cc54db3359.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4824e1f8cda50f80988857611da766685da94494.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4824e1f8cda50f80988857611da766685da94494.hip new file mode 100644 index 00000000000..b17bad6b12e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4824e1f8cda50f80988857611da766685da94494.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48280c91d7cd8712fd533e246a6b0f758834abc9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48280c91d7cd8712fd533e246a6b0f758834abc9.hip new file mode 100644 index 00000000000..c94863ece04 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48280c91d7cd8712fd533e246a6b0f758834abc9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_482e34930d11ff493007b1613993e01acc1af78d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_482e34930d11ff493007b1613993e01acc1af78d.hip new file mode 100644 index 00000000000..56917606f08 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_482e34930d11ff493007b1613993e01acc1af78d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48300e0aeabe337785d4c7b41796ce65df6cc42a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48300e0aeabe337785d4c7b41796ce65df6cc42a.hip new file mode 100644 index 00000000000..a23d30926d5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48300e0aeabe337785d4c7b41796ce65df6cc42a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_483eaea4096c8f5bee16a64860432f0634a253d8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_483eaea4096c8f5bee16a64860432f0634a253d8.hip new file mode 100644 index 00000000000..a29e1beb4fc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_483eaea4096c8f5bee16a64860432f0634a253d8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48435e5dd23e49e19dd313f9891ffec800ce74c2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48435e5dd23e49e19dd313f9891ffec800ce74c2.hip new file mode 100644 index 00000000000..0301c13f7dd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48435e5dd23e49e19dd313f9891ffec800ce74c2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_486f6c7c7655c34b7b9973ff357b0813f0a3fd7c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_486f6c7c7655c34b7b9973ff357b0813f0a3fd7c.hip new file mode 100644 index 00000000000..2a189a65a94 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_486f6c7c7655c34b7b9973ff357b0813f0a3fd7c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_487724686efd35731e5335efa949486c93ae26e3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_487724686efd35731e5335efa949486c93ae26e3.hip new file mode 100644 index 00000000000..ad8f080b495 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_487724686efd35731e5335efa949486c93ae26e3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_489e7be0f85656d012a6451b65f6c1d2613b187d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_489e7be0f85656d012a6451b65f6c1d2613b187d.hip new file mode 100644 index 00000000000..9cbb835abd1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_489e7be0f85656d012a6451b65f6c1d2613b187d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48ae3af78583258c4b13c11a442022e0e058bb85.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48ae3af78583258c4b13c11a442022e0e058bb85.hip new file mode 100644 index 00000000000..0eff4b86984 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48ae3af78583258c4b13c11a442022e0e058bb85.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48d7d145f96aa8958a9208d0c8887742a8c834fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48d7d145f96aa8958a9208d0c8887742a8c834fd.hip new file mode 100644 index 00000000000..f4bb78b0511 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48d7d145f96aa8958a9208d0c8887742a8c834fd.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, 
blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48e9e858abf6f77489f3fadc4ee81edacd26705a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48e9e858abf6f77489f3fadc4ee81edacd26705a.hip new file mode 100644 index 00000000000..3cb597dcfcd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48e9e858abf6f77489f3fadc4ee81edacd26705a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4904c5910a2d0595b39a3f87652a9d1ef4fcbe80.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4904c5910a2d0595b39a3f87652a9d1ef4fcbe80.hip new file mode 100644 index 00000000000..e023c01ef00 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4904c5910a2d0595b39a3f87652a9d1ef4fcbe80.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_490a68220a7b621ae9817d7b77f55de239b0a4f3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_490a68220a7b621ae9817d7b77f55de239b0a4f3.hip new file mode 100644 index 00000000000..006da14cc91 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_490a68220a7b621ae9817d7b77f55de239b0a4f3.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4911bdd71351610d55916d452495e599960d0a41.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4911bdd71351610d55916d452495e599960d0a41.hip new file mode 100644 index 00000000000..83d2fd0dc85 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4911bdd71351610d55916d452495e599960d0a41.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_492fbc418e829f89bcb8d93f8afd2869dd8dfccc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_492fbc418e829f89bcb8d93f8afd2869dd8dfccc.hip new file mode 100644 index 00000000000..f3e0cbb7a06 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_492fbc418e829f89bcb8d93f8afd2869dd8dfccc.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + true, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_49d4c005d723cdab9fbc307933c1257d114b539e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_49d4c005d723cdab9fbc307933c1257d114b539e.hip new file mode 100644 index 00000000000..d34f2a8e97d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_49d4c005d723cdab9fbc307933c1257d114b539e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
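Note on the common structure of the generated files above: every instance assembles one kernel type from a tile shape, a pipeline, and an epilogue, then provides an explicit specialization of a dispatch function template keyed on a `*_traits_<...>` type, with the actual launch guarded so that only gfx90a/gfx942 builds contain a real kernel call (other architectures compile to a stub returning 0.0). The sketch below is a simplified, hypothetical analogue of that trait-keyed specialization pattern; the names (`run_kernel_`, `my_fp16_traits`) are illustrative placeholders, not ck_tile APIs.

```cpp
// Illustrative sketch only: mirrors the trait-keyed explicit-specialization
// dispatch used by the autogenerated .hip files, with hypothetical names.
#include <iostream>

// Primary template: one dispatch entry point, specialized per kernel instance.
template <typename Traits>
float run_kernel_(float elapsed_ms);

// A "traits" type encodes the compile-time configuration of one instance,
// playing the role of fmha_fwd_traits_<...> / fmha_bwd_dq_dk_dv_traits_<...>.
struct my_fp16_traits {
    static constexpr int head_dim = 128;
    static const char* name() { return "fmha_fwd_fp16_hd128_example"; }
};

// Explicit specialization: each autogenerated file provides exactly one of
// these for its traits instantiation, guarded by the target GPU architecture.
template <>
float run_kernel_<my_fp16_traits>(float elapsed_ms)
{
#if defined(__gfx90a__) || defined(__gfx942__)
    // On supported architectures the real code builds kernel arguments and
    // launches through ck_tile::launch_kernel, returning the measured time.
    return elapsed_ms;
#else
    // Elsewhere the generated code compiles to a stub that returns 0.0.
    (void)elapsed_ms;
    return 0.0f;
#endif
}

int main()
{
    std::cout << my_fp16_traits::name() << ": "
              << run_kernel_<my_fp16_traits>(1.25f) << " ms\n";
}
```

A registry of such traits-keyed specializations is what lets the runtime pick the right generated kernel for a given dtype, head dimension, mask, and bias configuration without recompiling callers.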
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_49f5017cc0f5c8c8dc71492e7765cf729c1f225c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_49f5017cc0f5c8c8dc71492e7765cf729c1f225c.hip new file mode 100644 index 00000000000..cfac59538b5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_49f5017cc0f5c8c8dc71492e7765cf729c1f225c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a06b5b153ea6e8b1e20d9aad9d4633333fd98f5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a06b5b153ea6e8b1e20d9aad9d4633333fd98f5.hip new file mode 100644 index 00000000000..19df50232d2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a06b5b153ea6e8b1e20d9aad9d4633333fd98f5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a2e6b05e7e4de2cb23d815f8b2c8adf22131c0c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a2e6b05e7e4de2cb23d815f8b2c8adf22131c0c.hip new file mode 100644 index 00000000000..7c7e7201db3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a2e6b05e7e4de2cb23d815f8b2c8adf22131c0c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
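One further detail visible in the backward-pass instances: each file exposes three specializations — a timed entry (`fmha_bwd_dq_dk_dv_`) that launches through `ck_tile::launch_kernel` and returns a float, a `..._oneshot_` entry that invokes the kernel once on the given stream without timing, and a `..._get_name_` accessor returning the kernel's name string. The sketch below is a minimal host-side analogue of that three-function surface, assuming hypothetical placeholder names (`StreamCfg`, `DemoKernel`, `run_timed_`, `run_once_`, `get_name_`) rather than ck_tile APIs.

```cpp
// Sketch of the per-instance interface shape used by the generated backward
// files: timed launch, fire-and-forget launch, and a name accessor.
#include <chrono>
#include <iostream>
#include <string>
#include <thread>

struct StreamCfg { int stream_id = 0; int log_level = 0; };

struct DemoKernel {
    static std::string GetName() { return "demo_bwd_dq_dk_dv_kernel"; }
    void operator()() const {
        // Stand-in for the device kernel body.
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
};

// Timed entry: mirrors fmha_bwd_dq_dk_dv_<trait>(), which returns elapsed ms.
float run_timed_(const StreamCfg& s)
{
    if (s.log_level > 0)
        std::cout << ", " << DemoKernel::GetName() << std::flush;
    auto t0 = std::chrono::steady_clock::now();
    DemoKernel{}();
    auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<float, std::milli>(t1 - t0).count();
}

// One-shot entry: mirrors ..._oneshot_<trait>(), which launches once, untimed.
void run_once_(const StreamCfg& s) { (void)s; DemoKernel{}(); }

// Name accessor: mirrors ..._get_name_<trait>().
std::string get_name_() { return DemoKernel::GetName(); }

int main()
{
    StreamCfg cfg{0, 1};
    float ms = run_timed_(cfg);
    run_once_(cfg);
    std::cout << "\n" << get_name_() << " took ~" << ms << " ms\n";
}
```

The timed variant is what benchmarking and autotuning paths call, while the one-shot variant is the cheap path used once a configuration has already been selected.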
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a4a00bd6ea27ff20a2903d619e1361b5e27672a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a4a00bd6ea27ff20a2903d619e1361b5e27672a.hip new file mode 100644 index 00000000000..ddd06638883 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a4a00bd6ea27ff20a2903d619e1361b5e27672a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a5dbf601de5754c03a03a1a42395dc0766fb8ac.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a5dbf601de5754c03a03a1a42395dc0766fb8ac.hip new file mode 100644 index 00000000000..694fa722d3f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a5dbf601de5754c03a03a1a42395dc0766fb8ac.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a9f3da698a6103caf25d785928dd9f814ac27b4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a9f3da698a6103caf25d785928dd9f814ac27b4.hip new file mode 100644 index 00000000000..30c20f80593 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a9f3da698a6103caf25d785928dd9f814ac27b4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ab5d6e8fbfd92e9f7e47bda5cfbb0d4162a6319.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ab5d6e8fbfd92e9f7e47bda5cfbb0d4162a6319.hip new file mode 100644 index 00000000000..74f703b455b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ab5d6e8fbfd92e9f7e47bda5cfbb0d4162a6319.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + 
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4afd02981f92fbef6277c1985cc479c12bae9239.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4afd02981f92fbef6277c1985cc479c12bae9239.hip new file mode 100644 index 00000000000..8baadf07505 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4afd02981f92fbef6277c1985cc479c12bae9239.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, 
grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b1eaca3c37a82d19f8dc91f06764170069ca3af.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b1eaca3c37a82d19f8dc91f06764170069ca3af.hip new file mode 100644 index 00000000000..a75f6965bdf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b1eaca3c37a82d19f8dc91f06764170069ca3af.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b2e7f96b095ebfb66ecc7a75752fba2a63e4f37.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b2e7f96b095ebfb66ecc7a75752fba2a63e4f37.hip new file mode 100644 index 00000000000..948e9dea9c7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b2e7f96b095ebfb66ecc7a75752fba2a63e4f37.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b30f472f00bec9da0564ddc40e07112b5f9a117.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b30f472f00bec9da0564ddc40e07112b5f9a117.hip new file mode 100644 index 00000000000..0bc5de1e5a6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b30f472f00bec9da0564ddc40e07112b5f9a117.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b45948f2795293e72530b02669c4f549608ea7f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b45948f2795293e72530b02669c4f549608ea7f.hip new file mode 100644 index 00000000000..329cee50fdd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b45948f2795293e72530b02669c4f549608ea7f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b4c03c916393d6be7c5181369ebcef949eaa763.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b4c03c916393d6be7c5181369ebcef949eaa763.hip new file mode 100644 index 00000000000..29af28afdda --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b4c03c916393d6be7c5181369ebcef949eaa763.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b68e4d00295b294320b94bc777d7d34609127e0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b68e4d00295b294320b94bc777d7d34609127e0.hip new file mode 100644 index 00000000000..5ca0f423c56 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b68e4d00295b294320b94bc777d7d34609127e0.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b7393d55600c9892558248f4131fc06a6cf3309.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b7393d55600c9892558248f4131fc06a6cf3309.hip new file mode 100644 index 00000000000..65822931820 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b7393d55600c9892558248f4131fc06a6cf3309.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b74439f42140cdda9bb0f78d995d741212a35f4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b74439f42140cdda9bb0f78d995d741212a35f4.hip new file mode 100644 index 00000000000..487dce37eeb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b74439f42140cdda9bb0f78d995d741212a35f4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b76e5dce9af523422782dd25d8dcf6f25edc68f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b76e5dce9af523422782dd25d8dcf6f25edc68f.hip new file mode 100644 index 00000000000..d373ae30a1c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b76e5dce9af523422782dd25d8dcf6f25edc68f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4baf664bfdf070362bcc91af77d1bc406f744351.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4baf664bfdf070362bcc91af77d1bc406f744351.hip new file mode 100644 index 00000000000..3a7fcef5e0e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4baf664bfdf070362bcc91af77d1bc406f744351.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bc48576f285325345fa1205e5e7e01787b74f71.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bc48576f285325345fa1205e5e7e01787b74f71.hip new file mode 100644 index 00000000000..3a7a5e5295c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bc48576f285325345fa1205e5e7e01787b74f71.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bd4d46397a3749646b232b306688e52b8c6e584.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bd4d46397a3749646b232b306688e52b8c6e584.hip new file mode 100644 index 00000000000..8ba7618b58c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bd4d46397a3749646b232b306688e52b8c6e584.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4be4a98f150f3f9ab6f03b5fd0968c5454565c9a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4be4a98f150f3f9ab6f03b5fd0968c5454565c9a.hip new file mode 100644 index 00000000000..4ee9df9540e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4be4a98f150f3f9ab6f03b5fd0968c5454565c9a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4beca56234ff6fb4f23b9b24822887fd9a3d0df9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4beca56234ff6fb4f23b9b24822887fd9a3d0df9.hip new file mode 100644 index 00000000000..0a39028591c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4beca56234ff6fb4f23b9b24822887fd9a3d0df9.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bef4d120e71bfcfe61d67aa44d24ceb907c2b9e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bef4d120e71bfcfe61d67aa44d24ceb907c2b9e.hip new file mode 100644 index 00000000000..10933b2ba7c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bef4d120e71bfcfe61d67aa44d24ceb907c2b9e.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c0c50a1fac82d47dff2357ee3ddbfa0b2c8d487.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c0c50a1fac82d47dff2357ee3ddbfa0b2c8d487.hip new file mode 100644 index 00000000000..7536bd3bba0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c0c50a1fac82d47dff2357ee3ddbfa0b2c8d487.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c69d06e3f32e3b6d28d3e54ad764b472741c193.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c69d06e3f32e3b6d28d3e54ad764b472741c193.hip new file mode 100644 index 00000000000..c2cbb592946 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c69d06e3f32e3b6d28d3e54ad764b472741c193.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c8720923c3452e3aebd7b9c1b4b23f0c35d7e4f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c8720923c3452e3aebd7b9c1b4b23f0c35d7e4f.hip new file mode 100644 index 00000000000..2c5ad3e47e0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c8720923c3452e3aebd7b9c1b4b23f0c35d7e4f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cabdafad0bf803223ba5e8f474cd59233dc48cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cabdafad0bf803223ba5e8f474cd59233dc48cb.hip new file mode 100644 index 00000000000..ae7425e5dbf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cabdafad0bf803223ba5e8f474cd59233dc48cb.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cb1861e31df98bdfd731efc3d335055090d83af.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cb1861e31df98bdfd731efc3d335055090d83af.hip new file mode 100644 index 00000000000..cc3c001e9e9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cb1861e31df98bdfd731efc3d335055090d83af.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cd3de43cc1f7588d62a10362f59d113ee818846.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cd3de43cc1f7588d62a10362f59d113ee818846.hip new file mode 100644 index 00000000000..78cb9146355 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cd3de43cc1f7588d62a10362f59d113ee818846.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ce03571f1d2779bdeaf0a6a2d617e236d191c11.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ce03571f1d2779bdeaf0a6a2d617e236d191c11.hip new file mode 100644 index 00000000000..207b155c594 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ce03571f1d2779bdeaf0a6a2d617e236d191c11.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ce671f5defd76ca08614a7a1f184c36c0f1e2ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ce671f5defd76ca08614a7a1f184c36c0f1e2ab.hip new file mode 100644 index 00000000000..9bc545c3cd4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ce671f5defd76ca08614a7a1f184c36c0f1e2ab.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d3b1ae63e127b6e6afe39e354d4995afc5faeaf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d3b1ae63e127b6e6afe39e354d4995afc5faeaf.hip new file mode 100644 index 00000000000..73e33032d34 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d3b1ae63e127b6e6afe39e354d4995afc5faeaf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS 
AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d5f3cf0f78f73df79665c26b20b0805615e1b04.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d5f3cf0f78f73df79665c26b20b0805615e1b04.hip new file mode 100644 index 00000000000..ab6f84bf934 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d5f3cf0f78f73df79665c26b20b0805615e1b04.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d65e58c9f147498ed04dd51fe1393770603a6d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d65e58c9f147498ed04dd51fe1393770603a6d3.hip new file mode 100644 index 00000000000..107834d3836 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d65e58c9f147498ed04dd51fe1393770603a6d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d7dc0f356b630179916f8fc2041b7f1402b46df.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d7dc0f356b630179916f8fc2041b7f1402b46df.hip new file mode 100644 index 00000000000..0f4093125b8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d7dc0f356b630179916f8fc2041b7f1402b46df.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4da9e9b7277bc90518ab92860bef2097ba96d982.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4da9e9b7277bc90518ab92860bef2097ba96d982.hip new file mode 100644 index 00000000000..965cd600b8a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4da9e9b7277bc90518ab92860bef2097ba96d982.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4db2e63cfebcf84043f79be0321708cd159c62b9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4db2e63cfebcf84043f79be0321708cd159c62b9.hip new file mode 100644 index 00000000000..e0d9f465a9c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4db2e63cfebcf84043f79be0321708cd159c62b9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dbdd9c3f496a27bde68cf86374999ff2dd53505.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dbdd9c3f496a27bde68cf86374999ff2dd53505.hip new file mode 100644 index 00000000000..280f3877a19 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dbdd9c3f496a27bde68cf86374999ff2dd53505.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dc87b7d385e7b092e4706c464217b004fd8a6a4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dc87b7d385e7b092e4706c464217b004fd8a6a4.hip new file mode 100644 index 00000000000..e03645b96d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dc87b7d385e7b092e4706c464217b004fd8a6a4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dde56efe17f4fd36a11cc959320a5e43f1dc232.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dde56efe17f4fd36a11cc959320a5e43f1dc232.hip new file mode 100644 index 00000000000..8f794db20e2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dde56efe17f4fd36a11cc959320a5e43f1dc232.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e0a88ccef04e81b8c684b695f7cb4310e448915.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e0a88ccef04e81b8c684b695f7cb4310e448915.hip new file mode 100644 index 00000000000..090f8fb619d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e0a88ccef04e81b8c684b695f7cb4310e448915.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e15e4f16de26068cba30ef12fc29332d45e460e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e15e4f16de26068cba30ef12fc29332d45e460e.hip new file mode 100644 index 00000000000..869bf7f2627 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e15e4f16de26068cba30ef12fc29332d45e460e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e47f8fa40332c6ed12d9971e0b539049a871c34.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e47f8fa40332c6ed12d9971e0b539049a871c34.hip new file mode 100644 index 00000000000..f09735bdb22 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e47f8fa40332c6ed12d9971e0b539049a871c34.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e760de14b71a41882ec4a2c7362565af36d1a5d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e760de14b71a41882ec4a2c7362565af36d1a5d.hip new file mode 100644 index 00000000000..b3cf01ae478 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e760de14b71a41882ec4a2c7362565af36d1a5d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e79dce18e49ffe024fe4cd0693ad3399f5edaee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e79dce18e49ffe024fe4cd0693ad3399f5edaee.hip new file mode 100644 index 00000000000..a3e5a78f970 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e79dce18e49ffe024fe4cd0693ad3399f5edaee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e9a933b916285d9580a76df543cfafc88a536cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e9a933b916285d9580a76df543cfafc88a536cb.hip new file mode 100644 index 00000000000..7c4370995ee --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e9a933b916285d9580a76df543cfafc88a536cb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ec2075f394acfb14fae7b1ef4304fd9b654ba0d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ec2075f394acfb14fae7b1ef4304fd9b654ba0d.hip new file mode 100644 index 00000000000..e90dcb85f35 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ec2075f394acfb14fae7b1ef4304fd9b654ba0d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ed6da5357b67cc28aee4afa9523adaf055c4e32.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ed6da5357b67cc28aee4afa9523adaf055c4e32.hip new file mode 100644 index 00000000000..17bcce9dfb4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ed6da5357b67cc28aee4afa9523adaf055c4e32.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ef35d82ceb4af2e07719c16109c6d72eaedce67.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ef35d82ceb4af2e07719c16109c6d72eaedce67.hip new file mode 100644 index 00000000000..a8c94b4bffd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ef35d82ceb4af2e07719c16109c6d72eaedce67.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f0aded9d1baec3125ce8e176248cb146ca580fa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f0aded9d1baec3125ce8e176248cb146ca580fa.hip new file mode 100644 index 00000000000..b28531693ab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f0aded9d1baec3125ce8e176248cb146ca580fa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f1e1c969b57659e7e1367ac9ba10ed5ef5b69a9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f1e1c969b57659e7e1367ac9ba10ed5ef5b69a9.hip new file mode 100644 index 00000000000..ea856e55da7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f1e1c969b57659e7e1367ac9ba10ed5ef5b69a9.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f44435491aa68acb3217b0e693232c67641a2db.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f44435491aa68acb3217b0e693232c67641a2db.hip new file mode 100644 index 00000000000..c1738ae45e5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f44435491aa68acb3217b0e693232c67641a2db.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f4a5d56721bb1a1332a65882132a8c5763932ec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f4a5d56721bb1a1332a65882132a8c5763932ec.hip new file mode 100644 index 00000000000..3ba1ce44f26 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f4a5d56721bb1a1332a65882132a8c5763932ec.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) 
+ return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f6243c6850c0a2d2b7bf1476e12f95f187257b6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f6243c6850c0a2d2b7bf1476e12f95f187257b6.hip new file mode 100644 index 00000000000..2a28bfc26de --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f6243c6850c0a2d2b7bf1476e12f95f187257b6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fa4d21931b9afcbd70b1567995d3eeb6f9308aa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fa4d21931b9afcbd70b1567995d3eeb6f9308aa.hip new file mode 100644 index 00000000000..cd233c9b3bb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fa4d21931b9afcbd70b1567995d3eeb6f9308aa.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fa883a36a76edb276a66c5d779294f170d6d4b7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fa883a36a76edb276a66c5d779294f170d6d4b7.hip new file mode 100644 index 00000000000..09ffdabfe31 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fa883a36a76edb276a66c5d779294f170d6d4b7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fd34faa8b168e2ac7862641229e6146d3e28aee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fd34faa8b168e2ac7862641229e6146d3e28aee.hip new file mode 100644 index 00000000000..e633af5ce3e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fd34faa8b168e2ac7862641229e6146d3e28aee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fe530cbf6363a8f08a94728e45e88ecde299e7b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fe530cbf6363a8f08a94728e45e88ecde299e7b.hip new file mode 100644 index 00000000000..4868b21c49f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fe530cbf6363a8f08a94728e45e88ecde299e7b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ff20bafbf156fe8fb80bdd84a5d2f3a4a944c1a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ff20bafbf156fe8fb80bdd84a5d2f3a4a944c1a.hip new file mode 100644 index 00000000000..fc5da06f3e1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ff20bafbf156fe8fb80bdd84a5d2f3a4a944c1a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_501dcf3213efd214cc2ce8c9ba0027f991d241b4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_501dcf3213efd214cc2ce8c9ba0027f991d241b4.hip new file mode 100644 index 00000000000..ba839ff1f70 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_501dcf3213efd214cc2ce8c9ba0027f991d241b4.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5052b2318dbb78b1a82ef03666a35a623f44481b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5052b2318dbb78b1a82ef03666a35a623f44481b.hip new file mode 100644 index 00000000000..7e68e05bdd1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5052b2318dbb78b1a82ef03666a35a623f44481b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
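Note on the pattern repeated in the backward instance files above: each one exports a timed launcher (`fmha_bwd_dq_dk_dv_<trait>` / `fmha_bwd_convert_dq_<trait>`), which goes through `ck_tile::launch_kernel` and returns a float (0.0 when the translation unit is compiled for anything other than gfx90a/gfx942), plus a `*_oneshot_` variant that only enqueues the kernel on the stream, and a `*_get_name_` accessor. The snippet below is a minimal, self-contained C++ sketch of that timed/one-shot pairing; `fake_kernel`, `run_timed`, and `run_oneshot` are hypothetical stand-ins, not ck_tile or generated-code identifiers.

```cpp
#include <chrono>
#include <iostream>

// Hypothetical stand-in for the device kernel launch the generated files wrap.
static void fake_kernel() { /* pretend work */ }

// Mirrors the shape of fmha_bwd_dq_dk_dv_<trait>: run once and report a time.
float run_timed()
{
    auto t0 = std::chrono::steady_clock::now();
    fake_kernel();
    auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<float, std::milli>(t1 - t0).count();
}

// Mirrors fmha_bwd_dq_dk_dv_oneshot_<trait>: enqueue once, no timing result.
void run_oneshot() { fake_kernel(); }

int main()
{
    std::cout << "timed run took " << run_timed() << " ms\n";
    run_oneshot();
    return 0;
}
```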
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5093976cb7b32a8bd28ce92fc13af00a3e21f737.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5093976cb7b32a8bd28ce92fc13af00a3e21f737.hip new file mode 100644 index 00000000000..dcd91d230a5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5093976cb7b32a8bd28ce92fc13af00a3e21f737.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50e59bd079f4d205b613056f975fd2b4e372ab10.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50e59bd079f4d205b613056f975fd2b4e372ab10.hip new file mode 100644 index 00000000000..40d7912b548 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50e59bd079f4d205b613056f975fd2b4e372ab10.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50e7b11019fc2299d70869253877319b03388244.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50e7b11019fc2299d70869253877319b03388244.hip new file mode 100644 index 00000000000..60989e7b50e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50e7b11019fc2299d70869253877319b03388244.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50f887556a3540609649744957651ca667b91774.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50f887556a3540609649744957651ca667b91774.hip new file mode 100644 index 00000000000..f984328259b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50f887556a3540609649744957651ca667b91774.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50f915b4d9bd18a3c25a85917392ea4a5e88b349.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50f915b4d9bd18a3c25a85917392ea4a5e88b349.hip new file mode 100644 index 00000000000..af3ba52165a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50f915b4d9bd18a3c25a85917392ea4a5e88b349.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
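Each generated translation unit in this directory defines one `*_traits_<...>` alias (encoding head dimension, data type, pipeline, mask, bias, and padding flags) and then provides a `template <>` specialization of the launcher keyed on exactly that trait, so a central dispatcher can select the matching instance at compile time. The following is a minimal sketch of that dispatch-by-trait idiom in plain, self-contained C++; the names `run_` and `my_traits_` and the example parameters are hypothetical illustrations, not ck_tile or generated-code identifiers.

```cpp
#include <iostream>

// Primary template: declared here, defined only through explicit
// specializations, one per "generated" translation unit.
template <typename Trait>
float run_();

// A trait type encoding a compile-time configuration (e.g. head dim, causal flag).
template <int HeadDim, bool IsCausal>
struct my_traits_ {};

// One generated-style instance: the specialization for <64, true>.
template <>
float run_<my_traits_<64, true>>()
{
    std::cout << "launching hdim=64 causal kernel\n";
    return 0.0f; // the real files return the value from ck_tile::launch_kernel here
}

// A second instance, as another file in the directory would provide.
template <>
float run_<my_traits_<128, false>>()
{
    std::cout << "launching hdim=128 non-causal kernel\n";
    return 0.0f;
}

int main()
{
    // The dispatcher picks the trait at compile time and calls the
    // specialization supplied by the corresponding instance file.
    run_<my_traits_<64, true>>();
    run_<my_traits_<128, false>>();
    return 0;
}
```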
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_515128c6978449b33ce0c35b02a9e9aaad65ef7a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_515128c6978449b33ce0c35b02a9e9aaad65ef7a.hip new file mode 100644 index 00000000000..c7c98a48ed2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_515128c6978449b33ce0c35b02a9e9aaad65ef7a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_522a2a9435103ed405dc1500d31652f1d431a49d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_522a2a9435103ed405dc1500d31652f1d431a49d.hip new file mode 100644 index 00000000000..dec8c583648 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_522a2a9435103ed405dc1500d31652f1d431a49d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_523e5bf45ec5008aa3aba4773e68a78e122b2fe7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_523e5bf45ec5008aa3aba4773e68a78e122b2fe7.hip new file mode 100644 index 00000000000..a940c486f41 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_523e5bf45ec5008aa3aba4773e68a78e122b2fe7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52688999141a72e61322140db29043ef9f7fbc3d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52688999141a72e61322140db29043ef9f7fbc3d.hip new file mode 100644 index 00000000000..ab8b6d3cad2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52688999141a72e61322140db29043ef9f7fbc3d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_526c89b7a04758b4badbf9695b316f877b8bb053.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_526c89b7a04758b4badbf9695b316f877b8bb053.hip new file mode 100644 index 00000000000..d64de675647 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_526c89b7a04758b4badbf9695b316f877b8bb053.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_528db08068589c6e4c096054d26a2e5be63285b6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_528db08068589c6e4c096054d26a2e5be63285b6.hip new file mode 100644 index 00000000000..2cc28aa9544 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_528db08068589c6e4c096054d26a2e5be63285b6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52a89981a05963efcea7ba5c1e967638beeebbbb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52a89981a05963efcea7ba5c1e967638beeebbbb.hip new file mode 100644 index 00000000000..d74acfbc2d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52a89981a05963efcea7ba5c1e967638beeebbbb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52a8a323414448c50571a334f29bc0a38919b61d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52a8a323414448c50571a334f29bc0a38919b61d.hip new file mode 100644 index 00000000000..f0fd94abdbf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52a8a323414448c50571a334f29bc0a38919b61d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_532a6ffd8a21d3e98342fd401f0247f62ca4e038.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_532a6ffd8a21d3e98342fd401f0247f62ca4e038.hip new file mode 100644 index 00000000000..d95bdc8c5da --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_532a6ffd8a21d3e98342fd401f0247f62ca4e038.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5344427df3ae9392c4fc4c25c232196828e70648.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5344427df3ae9392c4fc4c25c232196828e70648.hip new file mode 100644 index 00000000000..137e4967d03 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5344427df3ae9392c4fc4c25c232196828e70648.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5382a30dcf702daae19bd6705864bfe36e09502c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5382a30dcf702daae19bd6705864bfe36e09502c.hip new file mode 100644 index 00000000000..8aae4b5333d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5382a30dcf702daae19bd6705864bfe36e09502c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_53bd60bd2afee49b30a583c32a45ae9f2076db08.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_53bd60bd2afee49b30a583c32a45ae9f2076db08.hip new file mode 100644 index 00000000000..9e573698aa1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_53bd60bd2afee49b30a583c32a45ae9f2076db08.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5403eec1cdd216d5c4a7ba977e2ef92a0d7fcc8b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5403eec1cdd216d5c4a7ba977e2ef92a0d7fcc8b.hip new file mode 100644 index 00000000000..6970d7ae20d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5403eec1cdd216d5c4a7ba977e2ef92a0d7fcc8b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_540bd57333c6839ccf5cf2e928edb996bc60c371.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_540bd57333c6839ccf5cf2e928edb996bc60c371.hip new file mode 100644 index 00000000000..f7df3f147d7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_540bd57333c6839ccf5cf2e928edb996bc60c371.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_541874a7633e5713720b9d084b6d1c6715a51a17.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_541874a7633e5713720b9d084b6d1c6715a51a17.hip new file mode 100644 index 00000000000..06a5903fcd9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_541874a7633e5713720b9d084b6d1c6715a51a17.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54208a6e8c5263e38f9ffcb062564ab61d2785ff.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54208a6e8c5263e38f9ffcb062564ab61d2785ff.hip new file mode 100644 index 00000000000..ac51939ea41 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54208a6e8c5263e38f9ffcb062564ab61d2785ff.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5435b4651a90e331fcdcf224282457e3dc038a30.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5435b4651a90e331fcdcf224282457e3dc038a30.hip new file mode 100644 index 00000000000..037237d41d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5435b4651a90e331fcdcf224282457e3dc038a30.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54402a22ceee3b665a3f24edb98b8398c35c6f5a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54402a22ceee3b665a3f24edb98b8398c35c6f5a.hip new file mode 100644 index 00000000000..ec0eb42fbdd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54402a22ceee3b665a3f24edb98b8398c35c6f5a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54548ad36fb92d0963893146c8db20f53cbf0c8f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54548ad36fb92d0963893146c8db20f53cbf0c8f.hip new file mode 100644 index 00000000000..8a402676fd3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54548ad36fb92d0963893146c8db20f53cbf0c8f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5467aea26852aa9a9e3dae76b906005ddf6fbae1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5467aea26852aa9a9e3dae76b906005ddf6fbae1.hip new file mode 100644 index 00000000000..14ce32a0c54 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5467aea26852aa9a9e3dae76b906005ddf6fbae1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_548b347672451e8391388a400d016803f4c4cf8d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_548b347672451e8391388a400d016803f4c4cf8d.hip new file mode 100644 index 00000000000..70204c41a3b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_548b347672451e8391388a400d016803f4c4cf8d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54940ce53998becf9bddf56df7d19894a7658168.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54940ce53998becf9bddf56df7d19894a7658168.hip new file mode 100644 index 00000000000..94ba3e2921f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54940ce53998becf9bddf56df7d19894a7658168.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_549b6956eaf678f7eb901567d1a515eddbedae5f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_549b6956eaf678f7eb901567d1a515eddbedae5f.hip new file mode 100644 index 00000000000..d809c4f367b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_549b6956eaf678f7eb901567d1a515eddbedae5f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54b6e18b10d529eb6b32d7c19c59eaefc7184376.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54b6e18b10d529eb6b32d7c19c59eaefc7184376.hip new file mode 100644 index 00000000000..087c17dfc24 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54b6e18b10d529eb6b32d7c19c59eaefc7184376.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54ff49018f1c12b9fa31e523ad40b9cc162ba34d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54ff49018f1c12b9fa31e523ad40b9cc162ba34d.hip new file mode 100644 index 00000000000..916aab9e71a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54ff49018f1c12b9fa31e523ad40b9cc162ba34d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else 
+ return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_555ba79201a585bc091ccfc326fd24e851d1eecc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_555ba79201a585bc091ccfc326fd24e851d1eecc.hip new file mode 100644 index 00000000000..6f75e3aadbd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_555ba79201a585bc091ccfc326fd24e851d1eecc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_556cd05288e1666f5c67fb87ad02ce660e4c589c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_556cd05288e1666f5c67fb87ad02ce660e4c589c.hip new file mode 100644 index 00000000000..55a2358fa98 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_556cd05288e1666f5c67fb87ad02ce660e4c589c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55b14cf2998a61611d1de2594e926fcdc378999c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55b14cf2998a61611d1de2594e926fcdc378999c.hip new file mode 100644 index 00000000000..b9d2a7223c0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55b14cf2998a61611d1de2594e926fcdc378999c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55bd9c4f1b7a0621c67f3e964d946ce22fb2fc80.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55bd9c4f1b7a0621c67f3e964d946ce22fb2fc80.hip new file mode 100644 index 00000000000..e9a73092d22 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55bd9c4f1b7a0621c67f3e964d946ce22fb2fc80.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55bf8444c1c26b91fd490c7216f4d0f8aa0a1f1a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55bf8444c1c26b91fd490c7216f4d0f8aa0a1f1a.hip new file mode 100644 index 00000000000..43bec7d7460 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55bf8444c1c26b91fd490c7216f4d0f8aa0a1f1a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55cda610c235987e13232e828f8d86fa88030560.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55cda610c235987e13232e828f8d86fa88030560.hip new file mode 100644 index 00000000000..b7a143108a7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55cda610c235987e13232e828f8d86fa88030560.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55ea83a47c6299fefa4220ed88f7a8e1dd938215.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55ea83a47c6299fefa4220ed88f7a8e1dd938215.hip new file mode 100644 index 00000000000..beda973479b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55ea83a47c6299fefa4220ed88f7a8e1dd938215.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
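Note the recurring guard around each launch: the kernel body is compiled only when the device pass targets gfx90a or gfx942, and every other compilation pass falls through to a bare return. The float returned from the guarded ck_tile::launch_kernel call appears to be a timing value, so 0.0 serves as the sentinel on unsupported targets; this interpretation is an assumption, not stated in the diff. Below is a small self-contained sketch of the same guard pattern, with a hypothetical function name and a placeholder in place of the real launch, written as plain host C++ so it runs anywhere.

// Sketch: arch-guarded launcher. __gfx90a__ / __gfx942__ are defined by the HIP
// device compiler per GPU architecture pass; in a host-only build neither macro
// is set and the sentinel path is taken.
#include <iostream>

float launch_guarded_kernel() {
#if defined(__gfx90a__) || defined(__gfx942__)
    // On supported architectures the generated code builds and launches the
    // ck_tile kernel here; 1.0f is only a stand-in for that launch path.
    return 1.0f;
#else
    // Unsupported-arch or host compilation: nothing is launched, return the sentinel.
    return 0.0f;
#endif
}

int main() {
    std::cout << "launcher returned " << launch_guarded_kernel() << "\n";
    return 0;
}

The same guard also explains why the oneshot_ variants return void and simply skip the launch entirely when the guard is not satisfied.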
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_566b4782793c6526bfce7362efbf6bf069928b2b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_566b4782793c6526bfce7362efbf6bf069928b2b.hip new file mode 100644 index 00000000000..b82f1883820 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_566b4782793c6526bfce7362efbf6bf069928b2b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_566e26d4969bc6bbe9b092bedab11cddb3360c0f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_566e26d4969bc6bbe9b092bedab11cddb3360c0f.hip new file mode 100644 index 00000000000..44eec56b09d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_566e26d4969bc6bbe9b092bedab11cddb3360c0f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56964a17f902257aca9d08c736516a2c67d9a0e9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56964a17f902257aca9d08c736516a2c67d9a0e9.hip new file mode 100644 index 00000000000..bff7d9de06d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56964a17f902257aca9d08c736516a2c67d9a0e9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56cc4399c5567a9495f17d54c712cc9e65e57521.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56cc4399c5567a9495f17d54c712cc9e65e57521.hip new file mode 100644 index 00000000000..aa735704459 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56cc4399c5567a9495f17d54c712cc9e65e57521.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56de9a7dfb1201b56528740e9d8a07b62710fcaf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56de9a7dfb1201b56528740e9d8a07b62710fcaf.hip new file mode 100644 index 00000000000..db75c7e723e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56de9a7dfb1201b56528740e9d8a07b62710fcaf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
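Each of the generated backward instance files in this patch follows the same shape: it pins one compile-time configuration (data type, block/warp tiles, mask, bias, dropout and padding flags), assembles the corresponding ck_tile pipeline, epilogue and kernel types, and then explicitly specializes a shared launcher template on a trait tag for that configuration. The sketch below is a minimal, standalone illustration of that specialization pattern; the names are made up for clarity and are not the real ck_tile/PyTorch identifiers.

#include <cstdio>

// Trait tag: one per generated configuration (head dim, causal mask, dropout, ...).
template <int HDim, bool kIsCausal, bool kHasDropout>
struct fmha_instance_traits {};

// Shared launcher template. Only declared here; each generated translation
// unit supplies the explicit specialization for its own trait tag.
template <typename Traits>
float launch_instance(float* out, int n);

// What one "generated" file contributes: pin a configuration and specialize.
using trait_hdim128_causal = fmha_instance_traits<128, true, false>;

template <>
float launch_instance<trait_hdim128_causal>(float* out, int n)
{
    // Stand-in for building kernel arguments and launching the tile kernel;
    // the real files return the elapsed time from ck_tile::launch_kernel.
    for (int i = 0; i < n; ++i) out[i] = 0.0f;
    return 0.0f;
}

int main()
{
    float buf[4];
    // The hand-written API file later in this patch picks a trait tag from the
    // runtime fmha_bwd_traits/fmha_bwd_args values and calls the matching
    // specialization, exactly as sketched here.
    launch_instance<trait_hdim128_causal>(buf, 4);
    std::printf("launched the hdim=128, causal, no-dropout instance\n");
    return 0;
}
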
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56ffe9e21362afe9c3a407c09d5de186954931a6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56ffe9e21362afe9c3a407c09d5de186954931a6.hip new file mode 100644 index 00000000000..9451a13d169 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56ffe9e21362afe9c3a407c09d5de186954931a6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5724d91c1fd6290a6cf8d52a3801ac6b921dc7d4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5724d91c1fd6290a6cf8d52a3801ac6b921dc7d4.hip new file mode 100644 index 00000000000..591d9a51cd2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5724d91c1fd6290a6cf8d52a3801ac6b921dc7d4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_572e68bd619e118292768f0925ccf92cbfa68415.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_572e68bd619e118292768f0925ccf92cbfa68415.hip new file mode 100644 index 00000000000..6c379707293 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_572e68bd619e118292768f0925ccf92cbfa68415.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + 
constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5732094f5917e9164ee0f973ac6ec47245a69101.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5732094f5917e9164ee0f973ac6ec47245a69101.hip new file mode 100644 index 00000000000..47b5b53a6b1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5732094f5917e9164ee0f973ac6ec47245a69101.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t 
kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5789f267d34c9961ced63ad07ffea2c6d2911415.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5789f267d34c9961ced63ad07ffea2c6d2911415.hip new file mode 100644 index 00000000000..0f617ac46e5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5789f267d34c9961ced63ad07ffea2c6d2911415.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5854f09511778dd1779a839b0b194896070f69ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5854f09511778dd1779a839b0b194896070f69ad.hip new file mode 100644 index 00000000000..b0b63f90c98 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5854f09511778dd1779a839b0b194896070f69ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58679919fcd292a2a69543de0db94e2985c9d364.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58679919fcd292a2a69543de0db94e2985c9d364.hip new file mode 100644 index 00000000000..068ce698b3f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58679919fcd292a2a69543de0db94e2985c9d364.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
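The forward instances (the 84-line files such as the one that follows) are smaller: one TileFmhaShape/TileFmhaTraits pair, a QRKSVS_ASYNC pipeline, a single epilogue, and one fmha_fwd_ specialization. Every launcher body, forward and backward, is also wrapped in an architecture guard so that a launch is only compiled for the gfx90a/gfx942 targets and every other compilation pass returns 0.0. The snippet below is an illustrative stand-in for that guard, not the real launcher; it relies only on the fact that the HIP compiler predefines __gfx90a__/__gfx942__ when compiling the device pass for the matching GPU.

#include <cstdio>

// On a non-matching pass (including the host pass) neither macro is defined,
// so the instance quietly reports 0.0 instead of launching anything.
float launch_guarded()
{
#if defined(__gfx90a__) || defined(__gfx942__)
    // Real path: create kernel arguments, launch, return the measured time.
    return 1.0f;
#else
    return 0.0f;
#endif
}

int main()
{
    std::printf("guarded launch returned %.1f\n", launch_guarded());
    return 0;
}
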
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58762476c7f2bb05dce92ec22c0acbeb03676746.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58762476c7f2bb05dce92ec22c0acbeb03676746.hip new file mode 100644 index 00000000000..00d0f0a1541 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58762476c7f2bb05dce92ec22c0acbeb03676746.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_587fc33d02b1932235b8d152e57559060211d591.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_587fc33d02b1932235b8d152e57559060211d591.hip new file mode 100644 index 00000000000..186b55badcd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_587fc33d02b1932235b8d152e57559060211d591.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58a784fb478ff5b3f1e2da9765a3a777efda92e3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58a784fb478ff5b3f1e2da9765a3a777efda92e3.hip new file mode 100644 index 00000000000..db36be36544 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58a784fb478ff5b3f1e2da9765a3a777efda92e3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58a7ab44bbd9fbc97c7805860d5f6ac81d6ae468.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58a7ab44bbd9fbc97c7805860d5f6ac81d6ae468.hip new file mode 100644 index 00000000000..399670f4e8e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58a7ab44bbd9fbc97c7805860d5f6ac81d6ae468.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58eb2edc7738d8d18ac359691da261ceaaf71788.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58eb2edc7738d8d18ac359691da261ceaaf71788.hip new file mode 100644 index 00000000000..28eaec0e987 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58eb2edc7738d8d18ac359691da261ceaaf71788.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
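The backward pass is split across three generated kernel families: an OGradDotO ("dot_do_o") preprocessing kernel like the one just above, the main dq/dk/dv kernel, and a convert_dq postprocessing step; the fmha_bwd entry point later in this patch chains their oneshot launchers on one stream. The preprocessing stage produces, per query row, the dot product of the output with its gradient, which is the D term used by the FlashAttention-style backward recurrence. The plain C++ sketch below shows a reduction of this form with illustrative shapes; the generated FmhaBwdOGradDotOKernel tiles the same reduction over blocks of rows (possibly with a scaling factor folded in).

#include <cstdio>
#include <vector>

// D[i] = sum_j dO[i][j] * O[i][j], computed independently for each query row.
std::vector<float> dot_do_o(const std::vector<std::vector<float>>& dO,
                            const std::vector<std::vector<float>>& O)
{
    std::vector<float> D(dO.size(), 0.0f);
    for (size_t i = 0; i < dO.size(); ++i)
        for (size_t j = 0; j < dO[i].size(); ++j)
            D[i] += dO[i][j] * O[i][j];
    return D;
}

int main()
{
    const auto D = dot_do_o({{1, 2}, {3, 4}}, {{5, 6}, {7, 8}});
    std::printf("D = [%.0f, %.0f]\n", D[0], D[1]); // prints 17, 53
    return 0;
}
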
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5919133d2ed892745013b2fc5d503414cf0a4d83.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5919133d2ed892745013b2fc5d503414cf0a4d83.hip new file mode 100644 index 00000000000..82964aaa461 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5919133d2ed892745013b2fc5d503414cf0a4d83.hip @@ -0,0 +1,14399 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +#include + +template +float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){ fmha_bwd_dot_do_o_oneshot_(s_, a); }, + [=](const ck_tile::stream_config& s_){ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }, + [=](const ck_tile::stream_config& s_){ fmha_bwd_convert_dq_oneshot_(s_, a); } + ); +#else + return 0.0; +#endif +} + +float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){ + float r = -1; + if(t.data_type.compare("fp16") == 0){ + if (t.hdim_q <= 32 && t.hdim_v <= 32) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + 
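+ // Dispatch note: every branch in this generated chain keys on (group mode, mask type, bias type, dbias, dropout, seqlen_q/seqlen_k/hdim_q/hdim_v divisibility, determinism), binds the matching dot_do_o_trait_/dq_dk_dv_trait_/convert_dq_trait_ aliases, and forwards to the templated fmha_bwd_ launcher defined above.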
else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 
128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, 
ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && 
t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) 
&& (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + 
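+ // Branches where hdim_q/hdim_v are multiples of 32 (like this one) select the KRKTRVR_IGLP backward pipeline; the non-multiple branches fall back to the plain KRKTRVR pipeline.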
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && 
(t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, 
false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else 
if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + 
(a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, 
false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, 
false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + 
(a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, 
true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) 
&& (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 
0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = 
fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == 
true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + 
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && 
(a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && 
(t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 
0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) 
&& + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && 
(t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == 
true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; 
+ r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + else if (t.hdim_q <= 64 && t.hdim_v <= 64) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && 
(a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, 
false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && 
(t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && 
(t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, 
ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; 
+ } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && 
(t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && 
t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && 
(a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && 
(t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 
0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && 
(t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type 
== bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 
!= 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using 
dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; 
+        r = fmha_bwd_(s, a);
+        return r;
+    }
+    else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) &&
+            (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) {
+        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>;
+        using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>;
+        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>;
+        r = fmha_bwd_(s, a);
+        return r;
+    }
+    else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) &&
+            (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) {
+        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>;
+        using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>;
+        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>;
+        r = fmha_bwd_(s, a);
+        return r;
+    }
+    else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) &&
+            (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) {
+        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>;
+        using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>;
+        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>;
+        r = fmha_bwd_(s, a);
+        return r;
+    }
+    else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) &&
+            (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) {
+        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>;
+        using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>;
+        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>;
+        r = fmha_bwd_(s, a);
+        return r;
+    }
+    else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) &&
+            (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) {
+        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>;
+        using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>;
+        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>;
+        r = fmha_bwd_(s, a);
+        return r;
+    }
+    else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) &&
+            (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) {
+        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>;
+        using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>;
+        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>;
+        r = fmha_bwd_(s, a);
+        return r;
+    }
+    else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) &&
+            (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) {
+        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>;
+        using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>;
+        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>;
+        r = fmha_bwd_(s, a);
+        return r;
+    }
+    else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) &&
+            (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) {
+        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>;
+        using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>;
+        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>;
+        r = fmha_bwd_(s, a);
+        return r;
+    }
+    else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) &&
+            (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) {
+        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>;
+        using dq_dk_dv_trait_ =
fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, 
false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) 
&& (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == 
false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true 
&& t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && 
(a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ 
= fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + 
} + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + 
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, 
ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, 
false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + else if (t.hdim_q <= 128 && t.hdim_v <= 128) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 
!= 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, 
true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 
0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ 
= fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, 
true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 
128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, 
false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == 
false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && 
(a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, 
ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, 
true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, 
false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else 
if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 
128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type 
!= mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && 
(a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + 
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, 
false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; 
+ r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) 
&& (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k 
% 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type 
== bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + 
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, 
ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == 
false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == 
bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + 
(a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + 
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && 
(true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == 
bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return 
r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + else if (t.hdim_q <= 256 && t.hdim_v <= 256) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + 
r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q 
% 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode 
== false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) 
&& (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && 
(a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) 
&& (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, 
true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 
0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, 
true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, 
a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true 
&& t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && 
(t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, 
false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, 
ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == 
false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 
256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, 
true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 
64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, 
true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && 
(a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, 
false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) 
&& (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, 
false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, 
false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; 
+ } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) 
&& (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + 
return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, 
true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, 
ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 
== 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, 
true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == 
true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + + } + else if(t.data_type.compare("bf16") == 0){ + if (t.hdim_q <= 32 && t.hdim_v <= 32) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else 
if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and 
a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using 
dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, 
false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == 
true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 
0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, 
false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) 
&& (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, 
true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && 
(a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, 
false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, 
false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && 
(a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, 
false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == 
bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, 
false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, 
true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == 
false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and 
a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ 
= fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == 
false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + 
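+ // launch fmha_bwd_ with the dot_do_o / dq_dk_dv / convert_dq traits selected above for this feature/shape combination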
r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == 
false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) 
&& (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, 
false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, 
false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = 
fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, 
ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, 
ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { 
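+        // Group-mode, no-mask, ALiBi-bias branch with dropout enabled (randval not stored),
+        // hdim_q/hdim_v multiples of 32, deterministic dq: the lines below instantiate the
+        // KRKTRVR_IGLP backward pipeline traits and the matching deterministic convert-dq trait.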
+ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && 
(a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, 
false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + else if (t.hdim_q <= 64 && t.hdim_v <= 64) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) 
&& (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, 
true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, 
false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 
!= 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, 
true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return 
r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && 
(t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 
32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v 
% 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && 
(t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == 
bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 
!= 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, 
true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else 
if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && 
(t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, 
ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == 
false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = 
fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && 
(t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; 
+ using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && 
(t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + 
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, 
ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type 
== bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && 
(a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using 
dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = 
fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, 
ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && 
(t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + 
(true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + else if (t.hdim_q <= 128 && t.hdim_v <= 128) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) 
&& (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && 
(t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && 
(a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, 
false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, 
false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, 
false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + 
(a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, 
true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == 
false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q 
% 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ 
= fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = 
fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type 
== mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && 
(t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && 
(t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, 
false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == 
false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) 
{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, 
false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, 
false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) 
&& (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == 
bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && 
t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && 
(a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, 
ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && 
(t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && 
(t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + 
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else 
if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == 
false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && 
(a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && 
(t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, 
ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && 
(t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return 
r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + 
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && 
(t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + else if (t.hdim_q <= 256 && t.hdim_v <= 256) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, 
ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && 
(t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && 
t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && 
(a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + 
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, 
true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, 
ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) 
&& (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, 
ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, 
a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && 
t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == 
true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ 
= fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else 
if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 
!= 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && 
(a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, 
true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, 
false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, 
false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 
0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, 
ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 
== 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, 
ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return 
r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && 
(t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && 
t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, 
ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) 
&& (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout 
== false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, 
true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + + } + + return r; +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5939e6610e41aff8d1ccdb66d9e84d3e48e8d379.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5939e6610e41aff8d1ccdb66d9e84d3e48e8d379.hip new file mode 100644 index 00000000000..ca8bae01086 --- /dev/null +++ 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5939e6610e41aff8d1ccdb66d9e84d3e48e8d379.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_594929c433b049a8cf949ff476309a8faf5c25fb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_594929c433b049a8cf949ff476309a8faf5c25fb.hip new file mode 100644 index 00000000000..b6cf1cb4411 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_594929c433b049a8cf949ff476309a8faf5c25fb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float 
fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_597a0276ec419f18f060a5186e6bb703ae434ac8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_597a0276ec419f18f060a5186e6bb703ae434ac8.hip new file mode 100644 index 00000000000..e4098c9f741 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_597a0276ec419f18f060a5186e6bb703ae434ac8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59901147b7188212b8d8feea15831a11425fe4b3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59901147b7188212b8d8feea15831a11425fe4b3.hip new file mode 100644 index 00000000000..f0a3ce4806d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59901147b7188212b8d8feea15831a11425fe4b3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59beb9cb4e161f9dcff79080149076488d436301.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59beb9cb4e161f9dcff79080149076488d436301.hip new file mode 100644 index 00000000000..544121f30b2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59beb9cb4e161f9dcff79080149076488d436301.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59d366421e0b51c90fa53c366d47ed8d51b3a329.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59d366421e0b51c90fa53c366d47ed8d51b3a329.hip new file mode 100644 index 00000000000..4f4e7b7ce42 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59d366421e0b51c90fa53c366d47ed8d51b3a329.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a05b4e7782bd0e29ca9f6d33fc59d4304136d41.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a05b4e7782bd0e29ca9f6d33fc59d4304136d41.hip new file mode 100644 index 00000000000..ca7a148219a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a05b4e7782bd0e29ca9f6d33fc59d4304136d41.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a216f777feec4752f5882677b18168225da4b53.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a216f777feec4752f5882677b18168225da4b53.hip new file mode 100644 index 00000000000..51c8dd8841e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a216f777feec4752f5882677b18168225da4b53.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a29b93cee012c79d4364502f1d90f947c73641d.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a29b93cee012c79d4364502f1d90f947c73641d.hip new file mode 100644 index 00000000000..f837edec025 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a29b93cee012c79d4364502f1d90f947c73641d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void 
fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a85ae0a16e4b293b549bcb6a3ee52df7fccca32.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a85ae0a16e4b293b549bcb6a3ee52df7fccca32.hip new file mode 100644 index 00000000000..fe6ae249435 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a85ae0a16e4b293b549bcb6a3ee52df7fccca32.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5aba1183efe205af38e79a1b2dccea5fa515d02e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5aba1183efe205af38e79a1b2dccea5fa515d02e.hip new file mode 100644 index 00000000000..d63a67ffba9 --- /dev/null +++ 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5aba1183efe205af38e79a1b2dccea5fa515d02e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ace1c9b00f160a17355d4583d49c47887ac33c8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ace1c9b00f160a17355d4583d49c47887ac33c8.hip new file mode 100644 index 00000000000..4e4e62b82a8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ace1c9b00f160a17355d4583d49c47887ac33c8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float 
fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5af96b404feac271dac8f4190180754480d3ba80.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5af96b404feac271dac8f4190180754480d3ba80.hip new file mode 100644 index 00000000000..147c5f04af0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5af96b404feac271dac8f4190180754480d3ba80.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b413bdc825ae863d53dab548f2145dc0de8fd37.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b413bdc825ae863d53dab548f2145dc0de8fd37.hip new file mode 100644 index 00000000000..e57559fe906 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b413bdc825ae863d53dab548f2145dc0de8fd37.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b55946ff3c15a44b9c741e9f6bbbcb5bd4c8577.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b55946ff3c15a44b9c741e9f6bbbcb5bd4c8577.hip new file mode 100644 index 00000000000..072e763c87e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b55946ff3c15a44b9c741e9f6bbbcb5bd4c8577.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b7a4ea3bb8905a22ae97a94c354b1cbe38093bb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b7a4ea3bb8905a22ae97a94c354b1cbe38093bb.hip new file mode 100644 index 00000000000..77e48a6748f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b7a4ea3bb8905a22ae97a94c354b1cbe38093bb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ba578c0e7abf1127dd0370f06d7278656c93ab9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ba578c0e7abf1127dd0370f06d7278656c93ab9.hip new file mode 100644 index 00000000000..5c55f1f2502 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ba578c0e7abf1127dd0370f06d7278656c93ab9.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5bc803342862aa30e23e5be7d84e611bc571c529.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5bc803342862aa30e23e5be7d84e611bc571c529.hip new file mode 100644 index 
00000000000..9079d495034 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5bc803342862aa30e23e5be7d84e611bc571c529.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5be9ed84ad9be1627db7a66af9370679816c0897.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5be9ed84ad9be1627db7a66af9370679816c0897.hip new file mode 100644 index 00000000000..e9d1cecfd89 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5be9ed84ad9be1627db7a66af9370679816c0897.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
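[Editorial note] Each autogenerated .hip file in this patch instantiates exactly one kernel configuration: it defines a trait alias (e.g. dq_dk_dv_trait_0) that encodes the head dimension, data type, pipeline, and mask/bias/dropout flags, then provides an explicit specialization of the launcher template for that single trait, guarded by `#if (defined(__gfx90a__) || defined(__gfx942__))` so that builds targeting other architectures compile a stub that returns 0.0. Below is a minimal, self-contained sketch of that one-specialization-per-translation-unit dispatch pattern. It is an illustration only: the names bwd_trait, run_bwd_instance, and run_bwd are hypothetical stand-ins, not the real fmha_bwd traits or launcher API from ck_tile.

#include <iostream>

// Tag type standing in for ck_tile::fp16_t; kept portable for the sketch.
struct fp16_tag {};

// A "trait" tag encodes one compile-time kernel configuration
// (head dim, data type, causal flag, ...).
template <int HeadDim, typename DataType, bool kIsCausal>
struct bwd_trait {};

// The primary template is declared but never defined generically; every
// generated translation unit supplies one explicit specialization for
// exactly one trait combination.
template <typename Trait>
float run_bwd_instance(float scale);

// --- what one generated file would contribute ---------------------------
using trait_hd64_fp16_causal = bwd_trait<64, fp16_tag, true>;

template <>
float run_bwd_instance<trait_hd64_fp16_causal>(float scale)
{
    // The real files build kernel arguments and launch a tile kernel here,
    // inside an arch guard; the sketch just reports the selected instance.
    std::cout << "hd64 / fp16 / causal instance, scale = " << scale << "\n";
    return 0.0f;
}

// --- runtime dispatch ----------------------------------------------------
// A thin dispatcher maps runtime parameters onto the matching
// specialization, one branch (or table entry) per generated instance.
float run_bwd(int head_dim, bool is_fp16, bool causal, float scale)
{
    if (head_dim == 64 && is_fp16 && causal)
        return run_bwd_instance<trait_hd64_fp16_causal>(scale);
    // ... further branches for the other generated configurations ...
    return -1.0f;  // unsupported configuration
}

int main()
{
    run_bwd(/*head_dim=*/64, /*is_fp16=*/true, /*causal=*/true, /*scale=*/0.125f);
}

Splitting the instances across many small translation units keeps per-file compile times and register pressure manageable and lets the build system compile only the configurations that are generated for the selected PYTORCH_ROCM_ARCH targets.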
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5bead6be6e39ece0e5d44335083336f7f546d2f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5bead6be6e39ece0e5d44335083336f7f546d2f8.hip new file mode 100644 index 00000000000..398cc46c34a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5bead6be6e39ece0e5d44335083336f7f546d2f8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5c36fc744dfb0d985c9113175e76c7ec1c935054.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5c36fc744dfb0d985c9113175e76c7ec1c935054.hip new file mode 100644 index 00000000000..2fe7b7fef1f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5c36fc744dfb0d985c9113175e76c7ec1c935054.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5c742b9ac6749f189d597ac97d46d35189472c50.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5c742b9ac6749f189d597ac97d46d35189472c50.hip new file mode 100644 index 00000000000..b98e8eb51d1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5c742b9ac6749f189d597ac97d46d35189472c50.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5cd03e29403ad53d6d52e5e81182ea6ff5aff2be.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5cd03e29403ad53d6d52e5e81182ea6ff5aff2be.hip new file mode 100644 index 00000000000..b9bf06ceaf6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5cd03e29403ad53d6d52e5e81182ea6ff5aff2be.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5cd41b6f578f3c903eb9d58ebfab62eb296044e0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5cd41b6f578f3c903eb9d58ebfab62eb296044e0.hip new file mode 100644 index 00000000000..15753ff0a0d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5cd41b6f578f3c903eb9d58ebfab62eb296044e0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = 
fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5d707d065ae152450f9def619ddc3dddb9089e88.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5d707d065ae152450f9def619ddc3dddb9089e88.hip new file mode 100644 index 00000000000..ac0e634a6ab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5d707d065ae152450f9def619ddc3dddb9089e88.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5d7ed4c885fb32a0b548186e56d64bab98071d30.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5d7ed4c885fb32a0b548186e56d64bab98071d30.hip new file mode 100644 index 00000000000..2d294dc9e07 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5d7ed4c885fb32a0b548186e56d64bab98071d30.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5daedab8931f2eefb649b91e80145cb71b63360c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5daedab8931f2eefb649b91e80145cb71b63360c.hip new file mode 100644 index 00000000000..6bbb194c6f9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5daedab8931f2eefb649b91e80145cb71b63360c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5de27c4081377f59363c2bf2ea8624217566d2d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5de27c4081377f59363c2bf2ea8624217566d2d3.hip new file mode 100644 index 00000000000..605258d502e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5de27c4081377f59363c2bf2ea8624217566d2d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e0abf4e2b6be3e2c555c2134705b9dcaee617ce.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e0abf4e2b6be3e2c555c2134705b9dcaee617ce.hip new file mode 100644 index 00000000000..6375ef7d3f2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e0abf4e2b6be3e2c555c2134705b9dcaee617ce.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e62968de58d9df7d687d671f37d63393f189321.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e62968de58d9df7d687d671f37d63393f189321.hip new file mode 100644 index 00000000000..8b7d6afbb74 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e62968de58d9df7d687d671f37d63393f189321.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e735b12d130ebf849ac5d6752e413ecf3e69fbf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e735b12d130ebf849ac5d6752e413ecf3e69fbf.hip new file mode 100644 index 00000000000..9aafba9d588 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e735b12d130ebf849ac5d6752e413ecf3e69fbf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e840be0741afa4d41fd4789c8300223fdc63ddc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e840be0741afa4d41fd4789c8300223fdc63ddc.hip new file mode 100644 index 00000000000..4713711a1ef --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e840be0741afa4d41fd4789c8300223fdc63ddc.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ea53f7c6370845fa94aa9b395c52fd1900b62de.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ea53f7c6370845fa94aa9b395c52fd1900b62de.hip new file mode 100644 
index 00000000000..42f40802671 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ea53f7c6370845fa94aa9b395c52fd1900b62de.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, 
grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5efe77ca5c394a60af0313072cdd132216a52bf3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5efe77ca5c394a60af0313072cdd132216a52bf3.hip new file mode 100644 index 00000000000..b2d7b8a8c80 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5efe77ca5c394a60af0313072cdd132216a52bf3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f20263fd84776f155519b3481be5e2c5b035585.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f20263fd84776f155519b3481be5e2c5b035585.hip new file 
mode 100644 index 00000000000..49d85122882 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f20263fd84776f155519b3481be5e2c5b035585.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f3c3bed2b584ea2031debf9f953f5f8f7012171.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f3c3bed2b584ea2031debf9f953f5f8f7012171.hip new file mode 100644 index 00000000000..811bee38dad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f3c3bed2b584ea2031debf9f953f5f8f7012171.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f71e663978dbcba859c5114ec675a712e343fd6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f71e663978dbcba859c5114ec675a712e343fd6.hip new file mode 100644 index 00000000000..317b0ffb52e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f71e663978dbcba859c5114ec675a712e343fd6.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f8925f929a5b26f3544ca31938aa75b3c59d34d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f8925f929a5b26f3544ca31938aa75b3c59d34d.hip new file mode 100644 index 00000000000..a46e8c9d1a0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f8925f929a5b26f3544ca31938aa75b3c59d34d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f954a393b7b5a7131c13d0c4578443f468a738d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f954a393b7b5a7131c13d0c4578443f468a738d.hip new file mode 100644 index 00000000000..20ed87ac5b4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f954a393b7b5a7131c13d0c4578443f468a738d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fa19223cf296d7fd10e15e2571e63c84a80fbb1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fa19223cf296d7fd10e15e2571e63c84a80fbb1.hip new file mode 100644 index 00000000000..caf989f9a71 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fa19223cf296d7fd10e15e2571e63c84a80fbb1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fa7fafd4227918e0c7f0c6ca3b2bd673cd07279.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fa7fafd4227918e0c7f0c6ca3b2bd673cd07279.hip new file mode 100644 index 00000000000..9c7f4c738f2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fa7fafd4227918e0c7f0c6ca3b2bd673cd07279.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fb062527121e627871b3f1b2a94b96c42e51205.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fb062527121e627871b3f1b2a94b96c42e51205.hip new file mode 100644 index 00000000000..86d4715e4ae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fb062527121e627871b3f1b2a94b96c42e51205.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fc66c5b53f83bf1e023e81e9d51f0285b3ae731.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fc66c5b53f83bf1e023e81e9d51f0285b3ae731.hip new file mode 100644 index 00000000000..064f6e873ce --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fc66c5b53f83bf1e023e81e9d51f0285b3ae731.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6018ab272d7306689c7dc5a6d5326efea1471235.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6018ab272d7306689c7dc5a6d5326efea1471235.hip new file mode 100644 index 00000000000..8f71d31bf6a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6018ab272d7306689c7dc5a6d5326efea1471235.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6049c01db99fce654e9351e711b113cf7424550a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6049c01db99fce654e9351e711b113cf7424550a.hip new file mode 100644 index 00000000000..64a88d00c1a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6049c01db99fce654e9351e711b113cf7424550a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_606f5e0b99814b0a82a731de36f28024bc317801.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_606f5e0b99814b0a82a731de36f28024bc317801.hip new file mode 100644 index 00000000000..c844b00596f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_606f5e0b99814b0a82a731de36f28024bc317801.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_60801d21c14796c08377349ec86a6c800af497b7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_60801d21c14796c08377349ec86a6c800af497b7.hip new file mode 100644 index 00000000000..3df31cb5546 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_60801d21c14796c08377349ec86a6c800af497b7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6082d55544b5280b49b071ea277fb1827193fa2a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6082d55544b5280b49b071ea277fb1827193fa2a.hip new file mode 100644 index 00000000000..e643421efbc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6082d55544b5280b49b071ea277fb1827193fa2a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_609616f72bf16a060fa50091ac139ddc06bf9d88.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_609616f72bf16a060fa50091ac139ddc06bf9d88.hip new file mode 100644 index 00000000000..1147254334f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_609616f72bf16a060fa50091ac139ddc06bf9d88.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_609f68180582384ba81aae2b1d4a4c52dde2c68c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_609f68180582384ba81aae2b1d4a4c52dde2c68c.hip new file mode 100644 index 00000000000..4e7b609594d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_609f68180582384ba81aae2b1d4a4c52dde2c68c.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_60efa9c427dc278c0d1bc31189f683cd45e4d873.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_60efa9c427dc278c0d1bc31189f683cd45e4d873.hip new file mode 100644 index 00000000000..6797e80b7a4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_60efa9c427dc278c0d1bc31189f683cd45e4d873.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61204f6805d5d830aa6fca2a9b5f238ed63c3a73.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61204f6805d5d830aa6fca2a9b5f238ed63c3a73.hip new file mode 100644 index 00000000000..c3fe918835d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61204f6805d5d830aa6fca2a9b5f238ed63c3a73.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61220f6dca850a5b5ccf1f619a267c40c37efeca.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61220f6dca850a5b5ccf1f619a267c40c37efeca.hip new file mode 100644 index 00000000000..59421208ec8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61220f6dca850a5b5ccf1f619a267c40c37efeca.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_614a9f10ebc51bde3f580ef527c17f89489c12c7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_614a9f10ebc51bde3f580ef527c17f89489c12c7.hip new file mode 100644 index 00000000000..2b9311b9837 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_614a9f10ebc51bde3f580ef527c17f89489c12c7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_615430cb65d8d540836c7f12b3367abd3c8e63d2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_615430cb65d8d540836c7f12b3367abd3c8e63d2.hip new file mode 100644 index 00000000000..ea187241a92 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_615430cb65d8d540836c7f12b3367abd3c8e63d2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_618031345ea71cc17e458eb97a559b7c94d3ae43.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_618031345ea71cc17e458eb97a559b7c94d3ae43.hip new file mode 100644 index 00000000000..1ad8d796b0f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_618031345ea71cc17e458eb97a559b7c94d3ae43.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61896aa9e4e4d7e494c1755b1e77a08e0e264f8d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61896aa9e4e4d7e494c1755b1e77a08e0e264f8d.hip new file mode 100644 index 00000000000..11228066f77 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61896aa9e4e4d7e494c1755b1e77a08e0e264f8d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61a44ac409e914c12281f1d26e5b52d8bfd0df75.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61a44ac409e914c12281f1d26e5b52d8bfd0df75.hip new file mode 100644 index 00000000000..c0b760b0640 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61a44ac409e914c12281f1d26e5b52d8bfd0df75.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61a9e92183ba87924e73ff0b5e25bd12d6038e69.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61a9e92183ba87924e73ff0b5e25bd12d6038e69.hip new file mode 100644 index 00000000000..1ada930929e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61a9e92183ba87924e73ff0b5e25bd12d6038e69.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62048a8ae1c0096f3372b0114c15edbe813425fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62048a8ae1c0096f3372b0114c15edbe813425fd.hip new file mode 100644 index 00000000000..1e3eea6a16b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62048a8ae1c0096f3372b0114c15edbe813425fd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6214f820b39a8ba81e547a78ed19a909ac13221c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6214f820b39a8ba81e547a78ed19a909ac13221c.hip new file mode 100644 index 00000000000..59a4cfac8c8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6214f820b39a8ba81e547a78ed19a909ac13221c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_621da34ee666903307d3a09b7a032f2a70054759.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_621da34ee666903307d3a09b7a032f2a70054759.hip new file mode 100644 index 00000000000..73ed5ba2bcd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_621da34ee666903307d3a09b7a032f2a70054759.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_628b28f65f19e7d1b22fb3b85b7cf3d09cd54ebc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_628b28f65f19e7d1b22fb3b85b7cf3d09cd54ebc.hip new file mode 100644 index 00000000000..37f7b6d970f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_628b28f65f19e7d1b22fb3b85b7cf3d09cd54ebc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_629e0b97b3fece7c12504f4c8f1860d611b57269.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_629e0b97b3fece7c12504f4c8f1860d611b57269.hip new file mode 100644 index 00000000000..56114de5ec3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_629e0b97b3fece7c12504f4c8f1860d611b57269.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62ab710e4acc711430745e05e036dd6a4d6bcdca.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62ab710e4acc711430745e05e036dd6a4d6bcdca.hip new file mode 100644 index 00000000000..f3e9617746e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62ab710e4acc711430745e05e036dd6a4d6bcdca.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62ba7a5a0f3a714eb5f9f2af20f7bfbc82a30350.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62ba7a5a0f3a714eb5f9f2af20f7bfbc82a30350.hip new file mode 100644 index 00000000000..107eaac4bb6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62ba7a5a0f3a714eb5f9f2af20f7bfbc82a30350.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62eb2f81e73d65fddce7ff43c397da6529317607.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62eb2f81e73d65fddce7ff43c397da6529317607.hip new file mode 100644 index 00000000000..d5268a6cfed --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62eb2f81e73d65fddce7ff43c397da6529317607.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_634d530731c7ade2c7beecfd1bbbca8583032217.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_634d530731c7ade2c7beecfd1bbbca8583032217.hip new file mode 100644 index 00000000000..a9bf4ac8966 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_634d530731c7ade2c7beecfd1bbbca8583032217.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6360621af3f7e1e81a8be48fea8d2750fdecbbf4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6360621af3f7e1e81a8be48fea8d2750fdecbbf4.hip new file mode 100644 index 00000000000..bd58ae5b073 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6360621af3f7e1e81a8be48fea8d2750fdecbbf4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6376eb68c550b50b9aea42a7a2cc3bda186b0e40.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6376eb68c550b50b9aea42a7a2cc3bda186b0e40.hip new file mode 100644 index 00000000000..d7dcbfd10a2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6376eb68c550b50b9aea42a7a2cc3bda186b0e40.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_63c411351ec59bdbed2590c599f9eddf7807b371.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_63c411351ec59bdbed2590c599f9eddf7807b371.hip new file mode 100644 index 00000000000..9a308c4fac7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_63c411351ec59bdbed2590c599f9eddf7807b371.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_63f121a3c8928c10a2d86b487cd13fa995da670d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_63f121a3c8928c10a2d86b487cd13fa995da670d.hip new file mode 100644 index 00000000000..e9eb60654f2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_63f121a3c8928c10a2d86b487cd13fa995da670d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_643b3798f11997d33ccb58d90ed6c10d5411b735.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_643b3798f11997d33ccb58d90ed6c10d5411b735.hip new file mode 100644 index 00000000000..314cb5f5ad7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_643b3798f11997d33ccb58d90ed6c10d5411b735.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_649336d59a8b35919e593217b6fd4314a04ea359.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_649336d59a8b35919e593217b6fd4314a04ea359.hip new file mode 100644 index 00000000000..d39bc8d5dd3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_649336d59a8b35919e593217b6fd4314a04ea359.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64a0ca185449a49fa485892fde6af745ba758167.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64a0ca185449a49fa485892fde6af745ba758167.hip new file mode 100644 index 00000000000..3aae70aeb0e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64a0ca185449a49fa485892fde6af745ba758167.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64b3488ddf3bb1a4870371882f0a5d267bdfdf73.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64b3488ddf3bb1a4870371882f0a5d267bdfdf73.hip new file mode 100644 index 00000000000..f3205430ef1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64b3488ddf3bb1a4870371882f0a5d267bdfdf73.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64c3c1e3dac623f07c2dc1b934ccb868cafcb38c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64c3c1e3dac623f07c2dc1b934ccb868cafcb38c.hip new file mode 100644 index 00000000000..7f03548932d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64c3c1e3dac623f07c2dc1b934ccb868cafcb38c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64cf03c0aa3f1b2a7b76b4e3418eb5063b982a29.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64cf03c0aa3f1b2a7b76b4e3418eb5063b982a29.hip new file mode 100644 index 00000000000..c5ef2de1b23 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64cf03c0aa3f1b2a7b76b4e3418eb5063b982a29.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64fe2db75cb20428856b02cd1cc8d7b393a6ad9c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64fe2db75cb20428856b02cd1cc8d7b393a6ad9c.hip new file mode 100644 index 00000000000..c90b78d5b36 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64fe2db75cb20428856b02cd1cc8d7b393a6ad9c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_65794d9c185b21f59274ac5d4db10a7abc0be968.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_65794d9c185b21f59274ac5d4db10a7abc0be968.hip new file mode 100644 index 00000000000..16f8f996144 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_65794d9c185b21f59274ac5d4db10a7abc0be968.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_658552954505a2092662071401e135e84956c4c0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_658552954505a2092662071401e135e84956c4c0.hip new file mode 100644 index 00000000000..ecf63208bb1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_658552954505a2092662071401e135e84956c4c0.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_65910c8b7a30acc731948ab58467fdbe4fe32f6d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_65910c8b7a30acc731948ab58467fdbe4fe32f6d.hip new file mode 100644 index 00000000000..ae319eab4af --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_65910c8b7a30acc731948ab58467fdbe4fe32f6d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_661b49505cfecbe4ec3e5c7371de3aaaa85ac9d5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_661b49505cfecbe4ec3e5c7371de3aaaa85ac9d5.hip new file mode 100644 index 00000000000..7d29797cc8a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_661b49505cfecbe4ec3e5c7371de3aaaa85ac9d5.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_661ffaf653085dd7f122d603bb3ba4b001e5f3c0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_661ffaf653085dd7f122d603bb3ba4b001e5f3c0.hip new file mode 100644 index 00000000000..71d6d758ce8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_661ffaf653085dd7f122d603bb3ba4b001e5f3c0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_662767e588220d0dc6137b00cc1d8dcc91e97134.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_662767e588220d0dc6137b00cc1d8dcc91e97134.hip new file mode 100644 index 00000000000..93eb241144d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_662767e588220d0dc6137b00cc1d8dcc91e97134.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6649f19deeaea20663bee781af7edced7f7a4fc0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6649f19deeaea20663bee781af7edced7f7a4fc0.hip new file mode 100644 index 00000000000..09dcbd26f3b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6649f19deeaea20663bee781af7edced7f7a4fc0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66968bbf7e210911fcb95ba90c79837230ab1ce3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66968bbf7e210911fcb95ba90c79837230ab1ce3.hip new file mode 100644 index 00000000000..9aabd72f0b4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66968bbf7e210911fcb95ba90c79837230ab1ce3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66a020f728df204ff51e37d2ddc21afb0aad5e7b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66a020f728df204ff51e37d2ddc21afb0aad5e7b.hip new file mode 100644 index 00000000000..6aef15d3c90 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66a020f728df204ff51e37d2ddc21afb0aad5e7b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66be70b088b20fc8de464167c35745461ddab640.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66be70b088b20fc8de464167c35745461ddab640.hip new file mode 100644 index 00000000000..eb2b20126c8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66be70b088b20fc8de464167c35745461ddab640.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66f651d3415562206c1049b172261fddba01ea6c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66f651d3415562206c1049b172261fddba01ea6c.hip new file mode 100644 index 00000000000..951a884ab1e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66f651d3415562206c1049b172261fddba01ea6c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_671828f15eec2a58be23063a1a8132d337cd26de.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_671828f15eec2a58be23063a1a8132d337cd26de.hip new file mode 100644 index 00000000000..d98041e8984 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_671828f15eec2a58be23063a1a8132d337cd26de.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + true, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6767cce35ab784aa42ebcb75af7305bc38a8721a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6767cce35ab784aa42ebcb75af7305bc38a8721a.hip new file mode 100644 index 00000000000..81b42c3c796 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6767cce35ab784aa42ebcb75af7305bc38a8721a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6785dcec0197fdbb50124ab06efa627f1a2c0567.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6785dcec0197fdbb50124ab06efa627f1a2c0567.hip new file mode 100644 index 00000000000..6c9bf31e43d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6785dcec0197fdbb50124ab06efa627f1a2c0567.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_678a4a8210a972bb2ed89d6ac754fb79438ab2da.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_678a4a8210a972bb2ed89d6ac754fb79438ab2da.hip new file mode 100644 index 00000000000..258b8f41131 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_678a4a8210a972bb2ed89d6ac754fb79438ab2da.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_67fb736c61088b8dd92fe0371f5c98e23bf9077f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_67fb736c61088b8dd92fe0371f5c98e23bf9077f.hip new file mode 100644 index 00000000000..b913b516021 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_67fb736c61088b8dd92fe0371f5c98e23bf9077f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_680e81c3700f130df142c9a37a368944ca548721.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_680e81c3700f130df142c9a37a368944ca548721.hip new file mode 100644 index 00000000000..c0083ea217a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_680e81c3700f130df142c9a37a368944ca548721.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_683e8a33fdb7053760c9c135002b0a94facbe015.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_683e8a33fdb7053760c9c135002b0a94facbe015.hip new file mode 100644 index 00000000000..a6b28615b17 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_683e8a33fdb7053760c9c135002b0a94facbe015.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_687f4aaafd1a5b9ee85aadc6fab79ad0c27a2ea2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_687f4aaafd1a5b9ee85aadc6fab79ad0c27a2ea2.hip new file mode 100644 index 00000000000..97f8d43f7d8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_687f4aaafd1a5b9ee85aadc6fab79ad0c27a2ea2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_688aaa193f332ed13e017e78ec07a7c80e45f6c5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_688aaa193f332ed13e017e78ec07a7c80e45f6c5.hip new file mode 100644 index 00000000000..7203357b70c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_688aaa193f332ed13e017e78ec07a7c80e45f6c5.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6905ba47078abd7a5b6a51eb93b26095517e7f70.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6905ba47078abd7a5b6a51eb93b26095517e7f70.hip new file mode 100644 index 00000000000..60c5c65b11f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6905ba47078abd7a5b6a51eb93b26095517e7f70.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_69214eb450c3b249017480efb8d092b0edad6dc3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_69214eb450c3b249017480efb8d092b0edad6dc3.hip new file mode 100644 index 00000000000..6ec78cd0c6f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_69214eb450c3b249017480efb8d092b0edad6dc3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6979ef43adffdb62100270a62706fb811963925a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6979ef43adffdb62100270a62706fb811963925a.hip new file mode 100644 index 00000000000..2972744f96e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6979ef43adffdb62100270a62706fb811963925a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_69cbe8eca7e3510f5caa7f13419cfbefbf031754.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_69cbe8eca7e3510f5caa7f13419cfbefbf031754.hip new file mode 100644 index 00000000000..ee857b9105f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_69cbe8eca7e3510f5caa7f13419cfbefbf031754.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a3f42d5c9ccdd3807e488b00f02bc6ab5d8d99a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a3f42d5c9ccdd3807e488b00f02bc6ab5d8d99a.hip new file mode 100644 index 00000000000..052d1ed21af --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a3f42d5c9ccdd3807e488b00f02bc6ab5d8d99a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a4b6226b355bf35d4d07aaef1828091f03ad2ec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a4b6226b355bf35d4d07aaef1828091f03ad2ec.hip new file mode 100644 index 00000000000..ba0aaab0870 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a4b6226b355bf35d4d07aaef1828091f03ad2ec.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a66604bb15f97a56847a7c968dbe32d247cbc13.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a66604bb15f97a56847a7c968dbe32d247cbc13.hip new file mode 100644 index 00000000000..17b8fd45e9b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a66604bb15f97a56847a7c968dbe32d247cbc13.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; 
+ return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a7b6781ffff9a42beebb4d73f0d15461ddd4479.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a7b6781ffff9a42beebb4d73f0d15461ddd4479.hip new file mode 100644 index 00000000000..2fc5358c55f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a7b6781ffff9a42beebb4d73f0d15461ddd4479.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a7eb3d86aa385f9ecffbc5ba10489e56856f918.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a7eb3d86aa385f9ecffbc5ba10489e56856f918.hip new file mode 100644 index 00000000000..b303c35613d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a7eb3d86aa385f9ecffbc5ba10489e56856f918.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a95543aeed81adfb6d847f78212585a36122ae3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a95543aeed81adfb6d847f78212585a36122ae3.hip new file mode 100644 index 00000000000..96a5e0049cb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a95543aeed81adfb6d847f78212585a36122ae3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6abeb7b50ae6a1fc62535b9a1dabbde6f177a9d0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6abeb7b50ae6a1fc62535b9a1dabbde6f177a9d0.hip new file mode 100644 index 00000000000..d10457b256f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6abeb7b50ae6a1fc62535b9a1dabbde6f177a9d0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6af23d1460abfe875e71f7911697c42fef0f41c5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6af23d1460abfe875e71f7911697c42fef0f41c5.hip new file mode 100644 index 00000000000..8835123a681 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6af23d1460abfe875e71f7911697c42fef0f41c5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6af4c15a119e805e4407b184625f57966f8833d9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6af4c15a119e805e4407b184625f57966f8833d9.hip new file mode 100644 index 00000000000..0f2bc69cb2f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6af4c15a119e805e4407b184625f57966f8833d9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6b0ef67ce0f178aa2863c4909f5bdd7f766c9b2f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6b0ef67ce0f178aa2863c4909f5bdd7f766c9b2f.hip new file mode 100644 index 00000000000..5506cd00149 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6b0ef67ce0f178aa2863c4909f5bdd7f766c9b2f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6b638314efcc4f16aa4a6e58e6caf2fda1711519.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6b638314efcc4f16aa4a6e58e6caf2fda1711519.hip new file mode 100644 index 00000000000..65b9b5c79d5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6b638314efcc4f16aa4a6e58e6caf2fda1711519.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6bad2ed9f91bc1efd89ea66cd5c775fa140cf931.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6bad2ed9f91bc1efd89ea66cd5c775fa140cf931.hip new file mode 100644 index 00000000000..58d654ccc04 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6bad2ed9f91bc1efd89ea66cd5c775fa140cf931.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6cfb7075345704340ff33dc0ef7c04ef127f26ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6cfb7075345704340ff33dc0ef7c04ef127f26ad.hip new file mode 100644 index 00000000000..bd3ee7d87a5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6cfb7075345704340ff33dc0ef7c04ef127f26ad.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d07bf9c05e41dcf2416e05dab4bdde17158db76.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d07bf9c05e41dcf2416e05dab4bdde17158db76.hip new file mode 100644 index 00000000000..5c80eada962 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d07bf9c05e41dcf2416e05dab4bdde17158db76.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d17b92fab5bee7717bf9aff6a6bef7cee3816e7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d17b92fab5bee7717bf9aff6a6bef7cee3816e7.hip new file mode 100644 index 00000000000..35204ee02cd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d17b92fab5bee7717bf9aff6a6bef7cee3816e7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d307974bdeeef95cca0d130ebb7aeb77fb1b6eb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d307974bdeeef95cca0d130ebb7aeb77fb1b6eb.hip new file mode 100644 index 00000000000..4c97ae6a198 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d307974bdeeef95cca0d130ebb7aeb77fb1b6eb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d40d762ed576832b3a752453e9881b5fe6d2650.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d40d762ed576832b3a752453e9881b5fe6d2650.hip new file mode 100644 index 00000000000..79752788020 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d40d762ed576832b3a752453e9881b5fe6d2650.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d470f5c6fb81032fcd7974180297d4bb2a8427d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d470f5c6fb81032fcd7974180297d4bb2a8427d.hip new file mode 100644 index 00000000000..db765120f69 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d470f5c6fb81032fcd7974180297d4bb2a8427d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d5aad18f59e47a3fa3278c7ef1a6372830c33d5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d5aad18f59e47a3fa3278c7ef1a6372830c33d5.hip new file mode 100644 index 00000000000..bb8407f16c7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d5aad18f59e47a3fa3278c7ef1a6372830c33d5.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6db86621d626722434f2ae9b7b8ab435a8dd8827.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6db86621d626722434f2ae9b7b8ab435a8dd8827.hip new file mode 100644 index 00000000000..37f3d2facd3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6db86621d626722434f2ae9b7b8ab435a8dd8827.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6dd707cf48a17d31abef94215c5720419faa0a39.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6dd707cf48a17d31abef94215c5720419faa0a39.hip new file mode 100644 index 00000000000..28ebbfe667b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6dd707cf48a17d31abef94215c5720419faa0a39.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e240106c771ebea461fc2a87b6da68e510aba70.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e240106c771ebea461fc2a87b6da68e510aba70.hip new file mode 100644 index 00000000000..9856101cf81 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e240106c771ebea461fc2a87b6da68e510aba70.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e6a4475ea795935f4cbf2dc0ac156a33d754587.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e6a4475ea795935f4cbf2dc0ac156a33d754587.hip new file mode 100644 index 00000000000..62f6798d224 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e6a4475ea795935f4cbf2dc0ac156a33d754587.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e7e1d245baabe2f6293e3d85318f9936b333500.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e7e1d245baabe2f6293e3d85318f9936b333500.hip new file mode 100644 index 00000000000..a8c34655cd7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e7e1d245baabe2f6293e3d85318f9936b333500.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e8cda718e10824956f0ee39bbb0891eafa45a7b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e8cda718e10824956f0ee39bbb0891eafa45a7b.hip new file mode 100644 index 00000000000..4b2292110fb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e8cda718e10824956f0ee39bbb0891eafa45a7b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6eca9cd905ea8b0454cf9564643894682b08cb97.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6eca9cd905ea8b0454cf9564643894682b08cb97.hip new file mode 100644 index 00000000000..1ea5c1d12a4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6eca9cd905ea8b0454cf9564643894682b08cb97.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6eebd0c2fbfc85f938b10535855c388971129a28.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6eebd0c2fbfc85f938b10535855c388971129a28.hip new file mode 100644 index 00000000000..471ba97a9da --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6eebd0c2fbfc85f938b10535855c388971129a28.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ef5803b33d97db72eb8a8528aeb3fc956a938cc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ef5803b33d97db72eb8a8528aeb3fc956a938cc.hip new file 
mode 100644 index 00000000000..b6e08fe9c10 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ef5803b33d97db72eb8a8528aeb3fc956a938cc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; 
+ auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f31b3345893eec8ed1ddf1d8de2512b46ff6187.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f31b3345893eec8ed1ddf1d8de2512b46ff6187.hip new file mode 100644 index 00000000000..3ecb92ff3a7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f31b3345893eec8ed1ddf1d8de2512b46ff6187.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + 
false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f3d098f8bb63133924aab70d26a6ed64018c13b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f3d098f8bb63133924aab70d26a6ed64018c13b.hip new file mode 100644 index 00000000000..4fd221d7dca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f3d098f8bb63133924aab70d26a6ed64018c13b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f8788c537cbf6833c58a6ca15c0a36de33c9fbd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f8788c537cbf6833c58a6ca15c0a36de33c9fbd.hip new file mode 100644 index 00000000000..28862392bac --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f8788c537cbf6833c58a6ca15c0a36de33c9fbd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f88527a2cdb5adf51407f4661a254bb32d7de23.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f88527a2cdb5adf51407f4661a254bb32d7de23.hip new file mode 100644 index 00000000000..b7b0813e1bf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f88527a2cdb5adf51407f4661a254bb32d7de23.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6fa6478cc27e52fd9511fbff38369c921155cfb9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6fa6478cc27e52fd9511fbff38369c921155cfb9.hip new file mode 100644 index 00000000000..941581b5be8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6fa6478cc27e52fd9511fbff38369c921155cfb9.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ff4605d82507fc4bd6e96095eaee5173ea41973.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ff4605d82507fc4bd6e96095eaee5173ea41973.hip new file mode 100644 index 00000000000..51026d28273 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ff4605d82507fc4bd6e96095eaee5173ea41973.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ff58a5186d69efd6062f3717bd315394ea6592b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ff58a5186d69efd6062f3717bd315394ea6592b.hip new file mode 100644 index 00000000000..25945a70f72 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ff58a5186d69efd6062f3717bd315394ea6592b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_703246f1f53a988cf252eff88bdf814bd382d3ac.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_703246f1f53a988cf252eff88bdf814bd382d3ac.hip new file mode 100644 index 00000000000..5a4038d810b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_703246f1f53a988cf252eff88bdf814bd382d3ac.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70586668a61ab88bc46b763df8f1c2ea52001ea0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70586668a61ab88bc46b763df8f1c2ea52001ea0.hip new file mode 100644 index 00000000000..b20806d0ece --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70586668a61ab88bc46b763df8f1c2ea52001ea0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70c8e45f6ea7cf5dba9eeadd0b19481d9f5defb7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70c8e45f6ea7cf5dba9eeadd0b19481d9f5defb7.hip new file mode 100644 index 00000000000..405ee3c4d1a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70c8e45f6ea7cf5dba9eeadd0b19481d9f5defb7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70cf755f1485c065222be4daab84283a9c3d0eb7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70cf755f1485c065222be4daab84283a9c3d0eb7.hip new file mode 100644 index 00000000000..fe11bb581ad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70cf755f1485c065222be4daab84283a9c3d0eb7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_714c5369aa848021e020d874289e3ae4e0f74d77.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_714c5369aa848021e020d874289e3ae4e0f74d77.hip new file mode 100644 index 00000000000..79d5dd08b05 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_714c5369aa848021e020d874289e3ae4e0f74d77.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7177f939ac3dae8749cbf4232dcf04d2cf63b48f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7177f939ac3dae8749cbf4232dcf04d2cf63b48f.hip new file mode 100644 index 00000000000..f18be3af4cf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7177f939ac3dae8749cbf4232dcf04d2cf63b48f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71a2d046629a4b65c90d0e18d061c4984062f844.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71a2d046629a4b65c90d0e18d061c4984062f844.hip new file mode 100644 index 00000000000..d834040de4e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71a2d046629a4b65c90d0e18d061c4984062f844.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71b6100efe30d836dab557ea4ac54c4b9d35c6aa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71b6100efe30d836dab557ea4ac54c4b9d35c6aa.hip new file mode 100644 index 00000000000..d98aeafb88f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71b6100efe30d836dab557ea4ac54c4b9d35c6aa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71dcbe9f481c92215f3b636bc0e86ce8f65e6472.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71dcbe9f481c92215f3b636bc0e86ce8f65e6472.hip new file mode 100644 index 00000000000..579a23f6092 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71dcbe9f481c92215f3b636bc0e86ce8f65e6472.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71e3980331dc4bcec6ab6f4c345c7b5f71356979.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71e3980331dc4bcec6ab6f4c345c7b5f71356979.hip new file mode 100644 index 00000000000..c4be9d2eb58 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71e3980331dc4bcec6ab6f4c345c7b5f71356979.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71e5fb3544dafa9da03fd2de4bb9bd0718f6009f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71e5fb3544dafa9da03fd2de4bb9bd0718f6009f.hip new file mode 100644 index 00000000000..a28fce33b7b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71e5fb3544dafa9da03fd2de4bb9bd0718f6009f.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7237ce5f3cf13ace3efc0b0227ae5a8c1fdfce1d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7237ce5f3cf13ace3efc0b0227ae5a8c1fdfce1d.hip new file mode 100644 index 00000000000..3baa424043e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7237ce5f3cf13ace3efc0b0227ae5a8c1fdfce1d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_724d1d4408196d611b2e0535bf8833652acbd6ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_724d1d4408196d611b2e0535bf8833652acbd6ef.hip new file mode 100644 index 00000000000..85adc2a66a6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_724d1d4408196d611b2e0535bf8833652acbd6ef.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7264e378e1ea1d4dd97f6949d66f3492883b663e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7264e378e1ea1d4dd97f6949d66f3492883b663e.hip new file mode 100644 index 00000000000..e2bbd10801e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7264e378e1ea1d4dd97f6949d66f3492883b663e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_72abb25dba0c48b380b2dabeb6ab7efaa706d180.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_72abb25dba0c48b380b2dabeb6ab7efaa706d180.hip new file mode 100644 index 00000000000..832b29ff44e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_72abb25dba0c48b380b2dabeb6ab7efaa706d180.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7309c38fc8a2d5ad6efd449107dc54a7509624fe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7309c38fc8a2d5ad6efd449107dc54a7509624fe.hip new file mode 100644 index 00000000000..252c8f7d0dd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7309c38fc8a2d5ad6efd449107dc54a7509624fe.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7344f96bed2f56793b1c2583485aa161cdf30379.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7344f96bed2f56793b1c2583485aa161cdf30379.hip new file mode 100644 index 00000000000..01c9d16f3a0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7344f96bed2f56793b1c2583485aa161cdf30379.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
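+// Configuration summary, taken from the aliases below: bf16 backward dq/dk/dv kernel,
+// first trait parameter 128 (presumably the head dim), block tile
+// <16, 128, 128, 16, 128, 16, 32, 128, 128>, KRKTRVR pipeline, ALIBI bias; the launch
+// bodies compile only for gfx90a/gfx942 and fall back to stubs on other targets.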
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7393267865f1c2b0aa1a09a586f54cec98eea4ae.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7393267865f1c2b0aa1a09a586f54cec98eea4ae.hip new file mode 100644 index 00000000000..28017c42ec7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7393267865f1c2b0aa1a09a586f54cec98eea4ae.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_73d4901b8ef034590314048de7223a572d61ee0f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_73d4901b8ef034590314048de7223a572d61ee0f.hip new file mode 100644 index 00000000000..3e13537d3d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_73d4901b8ef034590314048de7223a572d61ee0f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
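+// Configuration summary, taken from the aliases below: fp16 backward dq/dk/dv kernel,
+// first trait parameter 256 (presumably the head dim), block tile
+// <16, 64, 256, 16, 256, 16, 32, 256, 256>, KRKTRVR_IGLP pipeline, no bias; guarded
+// to compile the kernel launch only for gfx90a/gfx942.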
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_73ec21ed6e040260c4f04ef68ef9307aa86985a7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_73ec21ed6e040260c4f04ef68ef9307aa86985a7.hip new file mode 100644 index 00000000000..5423fe85c9d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_73ec21ed6e040260c4f04ef68ef9307aa86985a7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_741401abfbbbdf0dd1d62df8bc3e85371ead71d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_741401abfbbbdf0dd1d62df8bc3e85371ead71d6.hip new file mode 100644 index 00000000000..f5c10ce4d21 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_741401abfbbbdf0dd1d62df8bc3e85371ead71d6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_743176ecb1f0bc800c870861585edf56f88d7739.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_743176ecb1f0bc800c870861585edf56f88d7739.hip new file mode 100644 index 00000000000..5707e001982 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_743176ecb1f0bc800c870861585edf56f88d7739.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
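+// Configuration summary, taken from the aliases below: fp16 forward kernel,
+// first trait parameter 256 (presumably the head dim), block tile
+// <128, 128, 32, 256, 32, 256>, warp tile <32, 32, 16>, QRKSVS pipeline, ALIBI bias;
+// the launch body compiles only for gfx90a/gfx942 and returns a stub elsewhere.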
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_744ec604c577a27e0aae5b39711a9e2eb82801b6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_744ec604c577a27e0aae5b39711a9e2eb82801b6.hip new file mode 100644 index 00000000000..2b705c532a1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_744ec604c577a27e0aae5b39711a9e2eb82801b6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
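+// Configuration summary, taken from the aliases below: fp16 backward dq/dk/dv kernel,
+// first trait parameter 128 (presumably the head dim), block tile
+// <16, 128, 128, 16, 128, 16, 32, 128, 128>, KRKTRVR pipeline, ALIBI bias; kernel
+// launches are restricted to gfx90a/gfx942 by the preprocessor guards below.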
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_745705ae121a1a331527cedfe4d31218a428a0df.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_745705ae121a1a331527cedfe4d31218a428a0df.hip new file mode 100644 index 00000000000..b01539a5d08 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_745705ae121a1a331527cedfe4d31218a428a0df.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_748a3d76e8ab73af9a5d2302d33e3b1d1b866dd1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_748a3d76e8ab73af9a5d2302d33e3b1d1b866dd1.hip new file mode 100644 index 00000000000..c16fef433fe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_748a3d76e8ab73af9a5d2302d33e3b1d1b866dd1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; 
+ return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7497eca4d1a18306b406b367653622a8d64095bf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7497eca4d1a18306b406b367653622a8d64095bf.hip new file mode 100644 index 00000000000..c50336b6148 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7497eca4d1a18306b406b367653622a8d64095bf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_74ba59d347ce8916a22b40e6f22a3c89e13db4d0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_74ba59d347ce8916a22b40e6f22a3c89e13db4d0.hip new file mode 100644 index 00000000000..3cbcfca0c77 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_74ba59d347ce8916a22b40e6f22a3c89e13db4d0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_74d5f2aef029f2103bb419cc982cae99fd1a9253.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_74d5f2aef029f2103bb419cc982cae99fd1a9253.hip new file mode 100644 index 00000000000..9ef9694e78f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_74d5f2aef029f2103bb419cc982cae99fd1a9253.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
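+// Configuration summary, taken from the aliases below: bf16 forward kernel,
+// first trait parameter 32 (presumably the head dim), block tile
+// <128, 64, 16, 32, 32, 32>, warp tile <32, 32, 16>, QRKSVS_ASYNC pipeline, no bias;
+// compiled for gfx90a/gfx942 only, with a stub return on other targets.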
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7524904ac5a2040c7ea72aef5942212f291a21bf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7524904ac5a2040c7ea72aef5942212f291a21bf.hip new file mode 100644 index 00000000000..1e1bdfb1bd6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7524904ac5a2040c7ea72aef5942212f291a21bf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
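+// Configuration summary, taken from the aliases below: bf16 backward dq/dk/dv kernel,
+// first trait parameter 32 (presumably the head dim), block tile
+// <32, 128, 32, 32, 32, 32, 64, 32, 32>, KRKTRVR_IGLP pipeline, no bias; the launch
+// bodies are guarded to gfx90a/gfx942 as in the other generated files.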
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_758b211174da0f398b2a093e7389905b4f9c4060.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_758b211174da0f398b2a093e7389905b4f9c4060.hip new file mode 100644 index 00000000000..ae87a4aa231 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_758b211174da0f398b2a093e7389905b4f9c4060.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7596c14b8fee751d03f42ca48ea4f66e87fc2e2f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7596c14b8fee751d03f42ca48ea4f66e87fc2e2f.hip new file mode 100644 index 00000000000..b146cf40845 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7596c14b8fee751d03f42ca48ea4f66e87fc2e2f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7597ce4d2e5264bdeda47487d5bdb55a014c6616.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7597ce4d2e5264bdeda47487d5bdb55a014c6616.hip new file mode 100644 index 00000000000..8b7074e9990 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7597ce4d2e5264bdeda47487d5bdb55a014c6616.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75a310a6eb86e3e8baac7a930c3ffbef372942b3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75a310a6eb86e3e8baac7a930c3ffbef372942b3.hip new file mode 100644 index 00000000000..cf727977b73 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75a310a6eb86e3e8baac7a930c3ffbef372942b3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75c38912947881caa14b3fc7ab7bca317e296dc3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75c38912947881caa14b3fc7ab7bca317e296dc3.hip new file mode 100644 index 00000000000..1effa7fb7fc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75c38912947881caa14b3fc7ab7bca317e296dc3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
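+// This unit follows the same layout as the other generated files in this hunk: element-type,
+// block-tile and warp-tile aliases, a BlockFmhaBwdPipelineProblem and pipeline alias, the dK/dV
+// epilogues and kernel alias, a trait alias naming the exact variant, and finally explicit
+// specializations of the launch, oneshot and get_name entry points for that trait.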
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75f2010bf6c478d2f0eba77e912697661306c1cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75f2010bf6c478d2f0eba77e912697661306c1cb.hip new file mode 100644 index 00000000000..fb3ddc019fc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75f2010bf6c478d2f0eba77e912697661306c1cb.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75f21e38ad01fade35b1db40adabd75eb602410c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75f21e38ad01fade35b1db40adabd75eb602410c.hip new file mode 100644 index 00000000000..ca43e2867ac --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75f21e38ad01fade35b1db40adabd75eb602410c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7601e6aea44b96e94fb019501be6b102c6e6a654.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7601e6aea44b96e94fb019501be6b102c6e6a654.hip new file mode 100644 index 00000000000..0f882214410 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7601e6aea44b96e94fb019501be6b102c6e6a654.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
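+// As in the other generated launchers, the kernel launch bodies in this file are compiled only
+// for gfx90a / gfx942 targets; on any other architecture the #else branch keeps the
+// specialization defined but reduces the launch to a no-op that returns 0.0.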
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_761bde840c0c8149b24a8f6f264e963c4e9e8ceb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_761bde840c0c8149b24a8f6f264e963c4e9e8ceb.hip new file mode 100644 index 00000000000..181b498ce72 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_761bde840c0c8149b24a8f6f264e963c4e9e8ceb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_765940baaaa2ae6ade43ef4c94a220eaa63702b0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_765940baaaa2ae6ade43ef4c94a220eaa63702b0.hip new file mode 100644 index 00000000000..9bba041444d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_765940baaaa2ae6ade43ef4c94a220eaa63702b0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76674fc182dfa6329c73a354aa3adf458429444a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76674fc182dfa6329c73a354aa3adf458429444a.hip new file mode 100644 index 00000000000..cee13840f3d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76674fc182dfa6329c73a354aa3adf458429444a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76704ca28a4877a1e84022e022614709adabb280.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76704ca28a4877a1e84022e022614709adabb280.hip new file mode 100644 index 00000000000..b9fa502c84d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76704ca28a4877a1e84022e022614709adabb280.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_768c80fd3ea17813df1bf19a158186834fd00780.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_768c80fd3ea17813df1bf19a158186834fd00780.hip new file mode 100644 index 00000000000..93a77a8b9d1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_768c80fd3ea17813df1bf19a158186834fd00780.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
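+// The trait alias below appears to be the compile-time key for this variant: head dimension,
+// element type, pipeline enum, mask/dropout/bias kinds, plus a set of boolean feature flags.
+// The included dispatch header is presumed to declare the generic fmha_bwd_dq_dk_dv_ templates
+// that the explicit specializations in this file provide for exactly this trait, reached
+// roughly as: fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(stream_config, fmha_bwd_args).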
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76be322fc072ca19baa82707e260c6eba936ae19.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76be322fc072ca19baa82707e260c6eba936ae19.hip new file mode 100644 index 00000000000..b2b77825078 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76be322fc072ca19baa82707e260c6eba936ae19.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
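+// Two launch paths are generated per backward variant: fmha_bwd_dq_dk_dv_ goes through
+// ck_tile::launch_kernel and returns a float (presumably a timing value derived from the
+// stream_config), while fmha_bwd_dq_dk_dv_oneshot_ constructs the kernel functor and invokes
+// it directly with a stream_config carrying just the stream id, without timing.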
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76f884e9ca116ee47b446efe9fc770c178a858d5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76f884e9ca116ee47b446efe9fc770c178a858d5.hip new file mode 100644 index 00000000000..dfc3921c02e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76f884e9ca116ee47b446efe9fc770c178a858d5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_770ad1eb1b30ad8f1e7c17df486093129b2d5630.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_770ad1eb1b30ad8f1e7c17df486093129b2d5630.hip new file mode 100644 index 00000000000..edfcfeb1cc3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_770ad1eb1b30ad8f1e7c17df486093129b2d5630.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77200e875e0ef160b311c7de450c137772312d0d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77200e875e0ef160b311c7de450c137772312d0d.hip new file mode 100644 index 00000000000..e8423351ee7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77200e875e0ef160b311c7de450c137772312d0d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_772016803aa3ca6ebe785557118365f9be7c4339.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_772016803aa3ca6ebe785557118365f9be7c4339.hip new file mode 100644 index 00000000000..5a047976a5f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_772016803aa3ca6ebe785557118365f9be7c4339.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7726be8909f631c04d4395fa4ffd03a736f447f1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7726be8909f631c04d4395fa4ffd03a736f447f1.hip new file mode 100644 index 00000000000..62579cdaf7f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7726be8909f631c04d4395fa4ffd03a736f447f1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7728d5bec7941c9b6d5632bee8d67ed92b9c03ec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7728d5bec7941c9b6d5632bee8d67ed92b9c03ec.hip new file mode 100644 index 00000000000..ce79ac4a633 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7728d5bec7941c9b6d5632bee8d67ed92b9c03ec.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7764814a0de7702f0b7b5ce9dede6440603f4853.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7764814a0de7702f0b7b5ce9dede6440603f4853.hip new file mode 100644 index 00000000000..ca718c6ce1f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7764814a0de7702f0b7b5ce9dede6440603f4853.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77a814291d8f01870274149b9d82fb75921d6e20.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77a814291d8f01870274149b9d82fb75921d6e20.hip new file mode 100644 index 00000000000..147cdc179f1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77a814291d8f01870274149b9d82fb75921d6e20.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77d0223697ed41c4c2fd8830f8df6e5620db547f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77d0223697ed41c4c2fd8830f8df6e5620db547f.hip new file mode 100644 index 00000000000..555847bb1d1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77d0223697ed41c4c2fd8830f8df6e5620db547f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7831ce329f2a0812ebb1dd103ea4ba8cb7ba531d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7831ce329f2a0812ebb1dd103ea4ba8cb7ba531d.hip new file mode 100644 index 00000000000..35207bc7115 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7831ce329f2a0812ebb1dd103ea4ba8cb7ba531d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7838849e57ee9cd292e588f587a8079b57becfc8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7838849e57ee9cd292e588f587a8079b57becfc8.hip new file mode 100644 index 00000000000..b445dda9b5f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7838849e57ee9cd292e588f587a8079b57becfc8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_783ec08544591a22f59dc12f169b7327b4185a1a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_783ec08544591a22f59dc12f169b7327b4185a1a.hip new file mode 100644 index 00000000000..0e7b5e2b256 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_783ec08544591a22f59dc12f169b7327b4185a1a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_784c35fee4d372123631312f1051c43e1fa12378.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_784c35fee4d372123631312f1051c43e1fa12378.hip new file mode 100644 index 00000000000..3b96cb2f6da --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_784c35fee4d372123631312f1051c43e1fa12378.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78663faeb0425f45e8a0da0f7b1a5ddbee5e07e7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78663faeb0425f45e8a0da0f7b1a5ddbee5e07e7.hip new file mode 100644 index 00000000000..7fa113622d5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78663faeb0425f45e8a0da0f7b1a5ddbee5e07e7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7872c45ba170f2782c4b5b75cfc78ac79a4cf157.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7872c45ba170f2782c4b5b75cfc78ac79a4cf157.hip new file mode 100644 index 00000000000..14b1e581ef1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7872c45ba170f2782c4b5b75cfc78ac79a4cf157.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7878e2a4d3b96a552e03d1ffc33debfd50c9f7f1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7878e2a4d3b96a552e03d1ffc33debfd50c9f7f1.hip new file mode 100644 index 00000000000..81f8717ec47 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7878e2a4d3b96a552e03d1ffc33debfd50c9f7f1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78e1edca5abe1bb3e7aa946eab6484b7bed806a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78e1edca5abe1bb3e7aa946eab6484b7bed806a3.hip new file mode 100644 index 00000000000..d847abfc5e2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78e1edca5abe1bb3e7aa946eab6484b7bed806a3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78e945db4afa1330fe3978bc1bc9ae99828ae287.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78e945db4afa1330fe3978bc1bc9ae99828ae287.hip new file mode 100644 index 00000000000..cfa83bac96b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78e945db4afa1330fe3978bc1bc9ae99828ae287.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78f7e2a2c08cd87702793f91b6935cbe4c22be55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78f7e2a2c08cd87702793f91b6935cbe4c22be55.hip new file mode 100644 index 00000000000..4ac9b91617a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78f7e2a2c08cd87702793f91b6935cbe4c22be55.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_797750ac0b18b48f56ceb4640256e9bd3a36621a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_797750ac0b18b48f56ceb4640256e9bd3a36621a.hip new file mode 100644 index 00000000000..c01aa886d33 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_797750ac0b18b48f56ceb4640256e9bd3a36621a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7993fc08ac5c6ce7a2eceb1227f4e3718dc4cf5f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7993fc08ac5c6ce7a2eceb1227f4e3718dc4cf5f.hip new file mode 100644 index 00000000000..36ef130d401 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7993fc08ac5c6ce7a2eceb1227f4e3718dc4cf5f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79a7dce707954e765d97cb22e57d9bd6168860d9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79a7dce707954e765d97cb22e57d9bd6168860d9.hip new file mode 100644 index 00000000000..60a76667602 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79a7dce707954e765d97cb22e57d9bd6168860d9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + 
constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79d0b8053ddf99a4d4447656d733c2da026b3a7c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79d0b8053ddf99a4d4447656d733c2da026b3a7c.hip new file mode 100644 index 00000000000..cd262a1319b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79d0b8053ddf99a4d4447656d733c2da026b3a7c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + 
ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79f182ae021e23869d7bebf2a9b4575bdc910ed0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79f182ae021e23869d7bebf2a9b4575bdc910ed0.hip new file mode 100644 index 00000000000..b1bf4ec0f3c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79f182ae021e23869d7bebf2a9b4575bdc910ed0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a0ab620e6d62259a559e329460e46e6e3f7c3f9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a0ab620e6d62259a559e329460e46e6e3f7c3f9.hip new file mode 100644 index 00000000000..741db9df0c5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a0ab620e6d62259a559e329460e46e6e3f7c3f9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a13d62a715fd717f0d4101f787349cb49cbe70f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a13d62a715fd717f0d4101f787349cb49cbe70f.hip new file mode 100644 index 00000000000..603e0872921 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a13d62a715fd717f0d4101f787349cb49cbe70f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a242e5953f44316b6a4f6587ec26283ed6cbcae.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a242e5953f44316b6a4f6587ec26283ed6cbcae.hip new file mode 100644 index 00000000000..b93e01af0ce --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a242e5953f44316b6a4f6587ec26283ed6cbcae.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a2e032f6500fbc5468183415b6dd1d3e43f0bee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a2e032f6500fbc5468183415b6dd1d3e43f0bee.hip new file mode 100644 index 00000000000..e693144bf20 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a2e032f6500fbc5468183415b6dd1d3e43f0bee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a890b126da2d8cfbf84f048b779cac2dd56b509.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a890b126da2d8cfbf84f048b779cac2dd56b509.hip new file mode 100644 index 00000000000..c3275fe1336 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a890b126da2d8cfbf84f048b779cac2dd56b509.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a902ed4ae3cc6558c73b730ff3949778007a230.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a902ed4ae3cc6558c73b730ff3949778007a230.hip new file mode 100644 index 00000000000..56793e3969e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a902ed4ae3cc6558c73b730ff3949778007a230.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7aa14aa94d625b33df1adfa30ef4d91769592608.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7aa14aa94d625b33df1adfa30ef4d91769592608.hip new file mode 100644 index 00000000000..200e2791067 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7aa14aa94d625b33df1adfa30ef4d91769592608.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ab03a62e064864e1e9c1cd506c1b2e1786a777c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ab03a62e064864e1e9c1cd506c1b2e1786a777c.hip new file mode 100644 index 00000000000..ad200b3bc3b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ab03a62e064864e1e9c1cd506c1b2e1786a777c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7adf69b51f0a8cc9ae7e250e60df38758230fe4f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7adf69b51f0a8cc9ae7e250e60df38758230fe4f.hip new file mode 100644 index 00000000000..2022023dd50 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7adf69b51f0a8cc9ae7e250e60df38758230fe4f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7afd1a756247b15b078d15a39e350a07c22982da.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7afd1a756247b15b078d15a39e350a07c22982da.hip new file mode 100644 index 00000000000..9738d2e4d62 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7afd1a756247b15b078d15a39e350a07c22982da.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b2d3680c3578c7292349b58843aef7a82e0087d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b2d3680c3578c7292349b58843aef7a82e0087d.hip new file mode 100644 index 00000000000..4a409a802f6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b2d3680c3578c7292349b58843aef7a82e0087d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+
+// auto generated by generate.py
+#include
+
+using fmha_dtype_0 = ck_tile::fp16_t;
+
+using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>;
+using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
+
+using fmha_shape_0 = ck_tile::TileFmhaShape,
+                                            fmha_warp_tile_0,
+                                            ck_tile::sequence<2, 1, 1>,
+                                            fmha_warp_tile_0,
+                                            true>;
+
+using fmha_trait_0 = ck_tile::TileFmhaTraits;
+using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask;
+
+using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
+    typename FmhaFwdTypeConfig::QDataType,
+    typename FmhaFwdTypeConfig::KDataType,
+    typename FmhaFwdTypeConfig::VDataType,
+    typename FmhaFwdTypeConfig::SaccDataType,
+    typename FmhaFwdTypeConfig::SMPLComputeDataType,
+    typename FmhaFwdTypeConfig::BiasDataType,
+    typename FmhaFwdTypeConfig::RandValOutputDataType,
+    typename FmhaFwdTypeConfig::LSEDataType,
+    typename FmhaFwdTypeConfig::PDataType,
+    typename FmhaFwdTypeConfig::OaccDataType,
+    typename FmhaFwdTypeConfig::ODataType,
+    fmha_shape_0,
+    false,
+    fmha_mask_0,
+    fmha_trait_0>;
+
+using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
+    fmha_pipeline_problem_0>;
+
+using fmha_epilogue_0 =
+    ck_tile::Default2DEpilogue::OaccDataType,
+                                typename FmhaFwdTypeConfig::ODataType,
+                                true, true>>;
+
+using fmha_kernel_0 =
+    ck_tile::FmhaFwdKernel,
+                           fmha_pipeline_0,
+                           fmha_epilogue_0>;
+
+using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true,
+    ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>;
+
+#include
+
+template<>
+float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a)
+{
+    using k_ = fmha_kernel_0;
+    if(s.log_level_ > 0)
+        std::cout << ", " << k_::GetName() << std::flush;
+    auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a);
+    constexpr dim3 blocks = k_::BlockSize();
+    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
+#if (defined(__gfx90a__) || defined(__gfx942__))
+    return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs));
+#else
+    return 0.0;
+#endif
+}
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b5680f97836be4a369802e8115617a83875703e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b5680f97836be4a369802e8115617a83875703e.hip
new file mode 100644
index 00000000000..affd984987d
--- /dev/null
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b5680f97836be4a369802e8115617a83875703e.hip
@@ -0,0 +1,144 @@
+// ==========================================
+// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
+// @generated
+// ==========================================
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b67045d438a7e4b8f3a313a5df5a85f351c1be5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b67045d438a7e4b8f3a313a5df5a85f351c1be5.hip new file mode 100644 index 00000000000..ec8213154f6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b67045d438a7e4b8f3a313a5df5a85f351c1be5.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b7fa76609243a8709f349ffc0d9d88157f28dc9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b7fa76609243a8709f349ffc0d9d88157f28dc9.hip new file mode 100644 index 00000000000..b667ec694e6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b7fa76609243a8709f349ffc0d9d88157f28dc9.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b9a3bf1a9b37e0bd9bae6249609e5994dc0dba1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b9a3bf1a9b37e0bd9bae6249609e5994dc0dba1.hip new file mode 100644 index 00000000000..c13bfb9dd92 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b9a3bf1a9b37e0bd9bae6249609e5994dc0dba1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7bb7b63e8a4c1df4eac4d978e166867195bd6e53.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7bb7b63e8a4c1df4eac4d978e166867195bd6e53.hip new file mode 100644 index 00000000000..73fbe5ebab1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7bb7b63e8a4c1df4eac4d978e166867195bd6e53.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c19fc90e5a9c422dbf529d2def286f47dea0f50.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c19fc90e5a9c422dbf529d2def286f47dea0f50.hip new file mode 100644 index 00000000000..80f364bf21b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c19fc90e5a9c422dbf529d2def286f47dea0f50.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c23dde1a386436e9864c8fa5f1706c0d2fbfd0d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c23dde1a386436e9864c8fa5f1706c0d2fbfd0d.hip new file mode 100644 index 00000000000..9184e7f08e1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c23dde1a386436e9864c8fa5f1706c0d2fbfd0d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c3d8ef4da515960bf40eb1feb04d21950ad5ae5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c3d8ef4da515960bf40eb1feb04d21950ad5ae5.hip new file mode 100644 index 00000000000..222afbb4e77 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c3d8ef4da515960bf40eb1feb04d21950ad5ae5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c4710e8f4e27fae4ae079f1667c3a1879cb6da8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c4710e8f4e27fae4ae079f1667c3a1879cb6da8.hip new file mode 100644 index 00000000000..5526e7868ae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c4710e8f4e27fae4ae079f1667c3a1879cb6da8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7cbe4562c51d6829ec5942e11035c452fe318b3a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7cbe4562c51d6829ec5942e11035c452fe318b3a.hip new file mode 100644 index 00000000000..0d823727256 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7cbe4562c51d6829ec5942e11035c452fe318b3a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7cdc419d4248dfdeeab1f0980aec35fa134e52e0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7cdc419d4248dfdeeab1f0980aec35fa134e52e0.hip new file mode 100644 index 00000000000..89319da79dc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7cdc419d4248dfdeeab1f0980aec35fa134e52e0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d08373ace7087bdaca4ce8b0bc329f553f88d77.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d08373ace7087bdaca4ce8b0bc329f553f88d77.hip new file mode 100644 index 00000000000..a2c9587f665 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d08373ace7087bdaca4ce8b0bc329f553f88d77.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d0f767c17385eb7d756cbe8ed444d7cef72dea5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d0f767c17385eb7d756cbe8ed444d7cef72dea5.hip new file mode 100644 index 00000000000..721a61bf546 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d0f767c17385eb7d756cbe8ed444d7cef72dea5.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d12e9cb599d24631c082e3cf65d2c58b6d4d44f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d12e9cb599d24631c082e3cf65d2c58b6d4d44f.hip new file mode 100644 index 
00000000000..56d732bc792 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d12e9cb599d24631c082e3cf65d2c58b6d4d44f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d2f87c021e0b6a27b2d7e30351fd50f06414b5f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d2f87c021e0b6a27b2d7e30351fd50f06414b5f.hip new file mode 100644 index 00000000000..81f03941a6a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d2f87c021e0b6a27b2d7e30351fd50f06414b5f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + 
+template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d5667b27f15a06d4040354fba3601d48bb9c045.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d5667b27f15a06d4040354fba3601d48bb9c045.hip new file mode 100644 index 00000000000..041654a0473 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d5667b27f15a06d4040354fba3601d48bb9c045.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dac5d4cf103d658e129673549549f1276f134e0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dac5d4cf103d658e129673549549f1276f134e0.hip new file mode 100644 index 00000000000..977dd8080e9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dac5d4cf103d658e129673549549f1276f134e0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dd260849b86c46b685955cab54ba07d49b47954.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dd260849b86c46b685955cab54ba07d49b47954.hip new file mode 100644 index 00000000000..05280ff531b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dd260849b86c46b685955cab54ba07d49b47954.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ddd621da88c57798db1e689b93b692b6519ff96.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ddd621da88c57798db1e689b93b692b6519ff96.hip new file mode 100644 index 00000000000..7e2415b296c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ddd621da88c57798db1e689b93b692b6519ff96.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dfe21ee27f8a0ca0407ef0dea73cd73ae6940db.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dfe21ee27f8a0ca0407ef0dea73cd73ae6940db.hip new file mode 100644 index 00000000000..72b7af002ff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dfe21ee27f8a0ca0407ef0dea73cd73ae6940db.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e1bdde812c332c9fc58613698568a04771b9fa8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e1bdde812c332c9fc58613698568a04771b9fa8.hip new file mode 100644 index 00000000000..7ba6eedf761 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e1bdde812c332c9fc58613698568a04771b9fa8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e332a6aeecfb12dcf70c69157fd3137343fb9f6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e332a6aeecfb12dcf70c69157fd3137343fb9f6.hip new file mode 100644 index 00000000000..9c6f0c8905c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e332a6aeecfb12dcf70c69157fd3137343fb9f6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e6129eead18d13a4a6cb9550384fddabc7a2a16.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e6129eead18d13a4a6cb9550384fddabc7a2a16.hip new file mode 100644 index 00000000000..68838c8d32e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e6129eead18d13a4a6cb9550384fddabc7a2a16.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e89f79217037e361bb0909d06534e40f5026b4f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e89f79217037e361bb0909d06534e40f5026b4f.hip new file mode 100644 index 00000000000..93f3475e6e4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e89f79217037e361bb0909d06534e40f5026b4f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e9519dd0d0f940fd5efd61bd32df7528ba7e3fc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e9519dd0d0f940fd5efd61bd32df7528ba7e3fc.hip new file mode 100644 index 00000000000..31c94c7f898 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e9519dd0d0f940fd5efd61bd32df7528ba7e3fc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e9c7feb747241c9c7de2adf3a19933a1c4c0995.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e9c7feb747241c9c7de2adf3a19933a1c4c0995.hip new file mode 100644 index 00000000000..ef7c125928f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e9c7feb747241c9c7de2adf3a19933a1c4c0995.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ea9c37d92e344f3cc58cd4d1d00f19167e3623e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ea9c37d92e344f3cc58cd4d1d00f19167e3623e.hip new file mode 100644 index 00000000000..f8a361ce637 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ea9c37d92e344f3cc58cd4d1d00f19167e3623e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ec038393ec329a894aee9bbac078a40f57a4684.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ec038393ec329a894aee9bbac078a40f57a4684.hip new file mode 100644 index 00000000000..886343649f5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ec038393ec329a894aee9bbac078a40f57a4684.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ec04763d635c5bc3e810737b5d948c59f117d5a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ec04763d635c5bc3e810737b5d948c59f117d5a.hip new file mode 100644 index 00000000000..32f2eb1c249 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ec04763d635c5bc3e810737b5d948c59f117d5a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ee953cb24e28bcdc8f05783894b23cbf83bdf35.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ee953cb24e28bcdc8f05783894b23cbf83bdf35.hip new file mode 100644 index 00000000000..98bc5dde59e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ee953cb24e28bcdc8f05783894b23cbf83bdf35.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f6ccdb3c2d595fffd05bc5e6417b157276547fb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f6ccdb3c2d595fffd05bc5e6417b157276547fb.hip new file mode 100644 index 00000000000..255a436c7d8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f6ccdb3c2d595fffd05bc5e6417b157276547fb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f80d44e82e601dc48d4c8b4e710ef7265894b6c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f80d44e82e601dc48d4c8b4e710ef7265894b6c.hip new file mode 100644 index 00000000000..97b29f4190d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f80d44e82e601dc48d4c8b4e710ef7265894b6c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f9403cb91d6aabebf081afae94a8ba397d8d24f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f9403cb91d6aabebf081afae94a8ba397d8d24f.hip new file mode 100644 index 00000000000..869517b0f73 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f9403cb91d6aabebf081afae94a8ba397d8d24f.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f9bb3486fee7b7c9e24300b8a4e4ce88a11bfc0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f9bb3486fee7b7c9e24300b8a4e4ce88a11bfc0.hip new file mode 100644 index 00000000000..d4b65bf1602 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f9bb3486fee7b7c9e24300b8a4e4ce88a11bfc0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7fa76fc1b066a15b08dc6c24a7cf33a58b4cb6cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7fa76fc1b066a15b08dc6c24a7cf33a58b4cb6cb.hip new file mode 100644 index 00000000000..60a12bd2f99 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7fa76fc1b066a15b08dc6c24a7cf33a58b4cb6cb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7fe409f4421193fb48a54aa5f26bd6229d23204c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7fe409f4421193fb48a54aa5f26bd6229d23204c.hip new file mode 100644 index 00000000000..e493c2e5207 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7fe409f4421193fb48a54aa5f26bd6229d23204c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ff65c7abd9b0d8a2df9302d6dc167637b3a72f0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ff65c7abd9b0d8a2df9302d6dc167637b3a72f0.hip new file mode 100644 index 00000000000..a171deece2a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ff65c7abd9b0d8a2df9302d6dc167637b3a72f0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8004763f674dfb3f14b66dfdeb2a046e413ce2cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8004763f674dfb3f14b66dfdeb2a046e413ce2cb.hip new file mode 100644 index 00000000000..3ce9605b6ed --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8004763f674dfb3f14b66dfdeb2a046e413ce2cb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8007bf7ae1b71bf8ac4a793aa519ad333aa7a7ba.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8007bf7ae1b71bf8ac4a793aa519ad333aa7a7ba.hip new file mode 100644 index 00000000000..2d6b7f56d24 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8007bf7ae1b71bf8ac4a793aa519ad333aa7a7ba.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8021fa266c77e6b5bd1af2a9c22c686e5a6eac78.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8021fa266c77e6b5bd1af2a9c22c686e5a6eac78.hip new file mode 100644 index 00000000000..6ee5fecbdfb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8021fa266c77e6b5bd1af2a9c22c686e5a6eac78.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_802b21f9588d72c3c3e3b9a3b269f19c484d5aa4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_802b21f9588d72c3c3e3b9a3b269f19c484d5aa4.hip new file mode 100644 index 00000000000..740a14b3896 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_802b21f9588d72c3c3e3b9a3b269f19c484d5aa4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8046f566fa7188c92568b277354e8b06ad382544.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8046f566fa7188c92568b277354e8b06ad382544.hip new file mode 100644 index 00000000000..e8f7155fdb0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8046f566fa7188c92568b277354e8b06ad382544.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_806f9ab9baf631df1d3a8d801e4cf93a102526cf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_806f9ab9baf631df1d3a8d801e4cf93a102526cf.hip new file mode 100644 index 00000000000..668edd0effc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_806f9ab9baf631df1d3a8d801e4cf93a102526cf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_807545400aa6e70ff49a5f38ed6a218a180bd87f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_807545400aa6e70ff49a5f38ed6a218a180bd87f.hip new file mode 100644 index 00000000000..a0ae88c8061 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_807545400aa6e70ff49a5f38ed6a218a180bd87f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80987e2d765efc320eaee813607c94c80ee35aa4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80987e2d765efc320eaee813607c94c80ee35aa4.hip new file mode 100644 index 00000000000..742463346af --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80987e2d765efc320eaee813607c94c80ee35aa4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80a72d70d80b66c19e85daa00497308381050048.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80a72d70d80b66c19e85daa00497308381050048.hip new file mode 100644 index 00000000000..df68604c46f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80a72d70d80b66c19e85daa00497308381050048.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80bfb0e6032892cc58cef4dd403f305a5b76851b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80bfb0e6032892cc58cef4dd403f305a5b76851b.hip new file mode 100644 index 00000000000..2c57edb1b30 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80bfb0e6032892cc58cef4dd403f305a5b76851b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80cf0997573f4bcfbaaf75e40f519580a7495a17.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80cf0997573f4bcfbaaf75e40f519580a7495a17.hip new file mode 100644 index 00000000000..5f901c56ff4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80cf0997573f4bcfbaaf75e40f519580a7495a17.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80efc341089a50ed5669b3c86f6ddd9b124d1442.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80efc341089a50ed5669b3c86f6ddd9b124d1442.hip new file mode 100644 index 00000000000..44e9c2e4c6c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80efc341089a50ed5669b3c86f6ddd9b124d1442.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80f51f0e178c33e6196df1d2e47bd38bf5391cc8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80f51f0e178c33e6196df1d2e47bd38bf5391cc8.hip new file mode 100644 index 00000000000..3f856fea1f3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80f51f0e178c33e6196df1d2e47bd38bf5391cc8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80fb694fce7b4c3c459fca43c89c6002fbfdaef5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80fb694fce7b4c3c459fca43c89c6002fbfdaef5.hip new file mode 100644 index 00000000000..bd032e27668 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80fb694fce7b4c3c459fca43c89c6002fbfdaef5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS 
AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_810dd4e870ceda3ba9b5f0084a4b025b2e609d57.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_810dd4e870ceda3ba9b5f0084a4b025b2e609d57.hip new file mode 100644 index 00000000000..ef304c578a7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_810dd4e870ceda3ba9b5f0084a4b025b2e609d57.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_811db756577b61cde9fe8279d956980db9ee21a4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_811db756577b61cde9fe8279d956980db9ee21a4.hip new file mode 100644 index 00000000000..564d8374617 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_811db756577b61cde9fe8279d956980db9ee21a4.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_813e60e8405aca3f7fbed19452ae37574ada9a77.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_813e60e8405aca3f7fbed19452ae37574ada9a77.hip new file mode 100644 index 00000000000..e5d2004adf2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_813e60e8405aca3f7fbed19452ae37574ada9a77.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_815918206483d2ae04a45aa67d69dfb986587214.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_815918206483d2ae04a45aa67d69dfb986587214.hip new file mode 100644 index 00000000000..17be14af25b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_815918206483d2ae04a45aa67d69dfb986587214.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_816c48e129a0235cb3a19124ddb28cce286fb368.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_816c48e129a0235cb3a19124ddb28cce286fb368.hip new file mode 100644 index 00000000000..b90714501b0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_816c48e129a0235cb3a19124ddb28cce286fb368.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81acf1d17650712b71a499bb66909bfcfcb6aecb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81acf1d17650712b71a499bb66909bfcfcb6aecb.hip new file mode 100644 index 00000000000..06a07db689a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81acf1d17650712b71a499bb66909bfcfcb6aecb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( 
+ ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81bb8f13b6f20a72c9ce6d0b53f81eddbf05f1c6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81bb8f13b6f20a72c9ce6d0b53f81eddbf05f1c6.hip new file mode 100644 index 00000000000..b5b601d7d53 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81bb8f13b6f20a72c9ce6d0b53f81eddbf05f1c6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81dd3ea61bb61de02667b14f5a94198f48c7307b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81dd3ea61bb61de02667b14f5a94198f48c7307b.hip new file mode 100644 index 00000000000..1ea30c92f14 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81dd3ea61bb61de02667b14f5a94198f48c7307b.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81f6c575c3fa2ccc7e65022f1ba65c8cfc16541e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81f6c575c3fa2ccc7e65022f1ba65c8cfc16541e.hip new file mode 100644 index 00000000000..c348e217f75 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81f6c575c3fa2ccc7e65022f1ba65c8cfc16541e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82048cf91270631f98ac37dc488a1fb2e00ce004.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82048cf91270631f98ac37dc488a1fb2e00ce004.hip new file mode 100644 index 00000000000..ef460fe8a8c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82048cf91270631f98ac37dc488a1fb2e00ce004.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8250f27341241086515d833aa53ae873d4ece3fa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8250f27341241086515d833aa53ae873d4ece3fa.hip new file mode 100644 index 00000000000..611502df731 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8250f27341241086515d833aa53ae873d4ece3fa.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8278845045d68027dcf3bf867ecde2fb12ec51d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8278845045d68027dcf3bf867ecde2fb12ec51d3.hip new file mode 100644 index 00000000000..d95a9874bd8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8278845045d68027dcf3bf867ecde2fb12ec51d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82ad0c0580516485ea432d98f53e73f6dfec548c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82ad0c0580516485ea432d98f53e73f6dfec548c.hip new file mode 100644 index 00000000000..48f4d0aed93 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82ad0c0580516485ea432d98f53e73f6dfec548c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82c932e6eaaf44861c794539d9caf8b50192fc44.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82c932e6eaaf44861c794539d9caf8b50192fc44.hip new file mode 100644 index 00000000000..de5126d11e1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82c932e6eaaf44861c794539d9caf8b50192fc44.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82d7f61e6313930f063758b61102e7a43b118beb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82d7f61e6313930f063758b61102e7a43b118beb.hip new file mode 100644 index 00000000000..23af435c18b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82d7f61e6313930f063758b61102e7a43b118beb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82f0f3d71108dcc49234a258f0f3b21ea2123cc0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82f0f3d71108dcc49234a258f0f3b21ea2123cc0.hip new file mode 100644 index 00000000000..d37bf9ec4a1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82f0f3d71108dcc49234a258f0f3b21ea2123cc0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82f1d7e1a93bf2fa80c409e6827ea88af56c44f0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82f1d7e1a93bf2fa80c409e6827ea88af56c44f0.hip new file mode 100644 index 00000000000..cb1531a3738 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82f1d7e1a93bf2fa80c409e6827ea88af56c44f0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8301bfc0394936a68fa0098580f06e77c88ebed9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8301bfc0394936a68fa0098580f06e77c88ebed9.hip new file mode 100644 index 00000000000..2f3395e4f40 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8301bfc0394936a68fa0098580f06e77c88ebed9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83080406598df6bd3102db70a554e496e29db96a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83080406598df6bd3102db70a554e496e29db96a.hip new file mode 100644 index 00000000000..3b17343a968 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83080406598df6bd3102db70a554e496e29db96a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_830e3532f27b391585d5de90f3bdf97992b67651.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_830e3532f27b391585d5de90f3bdf97992b67651.hip new file mode 100644 index 00000000000..d132dabefd0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_830e3532f27b391585d5de90f3bdf97992b67651.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8352031044ef2e4a22e27ad04ab5d2c02121faee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8352031044ef2e4a22e27ad04ab5d2c02121faee.hip new file mode 100644 index 00000000000..7e5865723fc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8352031044ef2e4a22e27ad04ab5d2c02121faee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; 
+ return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_835a906031a258c6362313eec783678bd8125c91.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_835a906031a258c6362313eec783678bd8125c91.hip new file mode 100644 index 00000000000..86c405bdf4e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_835a906031a258c6362313eec783678bd8125c91.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_836a308c2d2afd6e0dfbfda61984b631c4ccffc6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_836a308c2d2afd6e0dfbfda61984b631c4ccffc6.hip new file mode 100644 index 00000000000..d5d7249e38f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_836a308c2d2afd6e0dfbfda61984b631c4ccffc6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83d580a612af85533c87aecdd7b0345c71b75980.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83d580a612af85533c87aecdd7b0345c71b75980.hip new file mode 100644 index 00000000000..cca0753895f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83d580a612af85533c87aecdd7b0345c71b75980.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83d920a76114c63156740ba5dd6f3846c4b21c28.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83d920a76114c63156740ba5dd6f3846c4b21c28.hip new file mode 100644 index 00000000000..63cc76c2bc7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83d920a76114c63156740ba5dd6f3846c4b21c28.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83ddca2c6ecbba4314c434e7471ffb8fa642f936.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83ddca2c6ecbba4314c434e7471ffb8fa642f936.hip new file mode 100644 index 00000000000..69d86ffa6cf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83ddca2c6ecbba4314c434e7471ffb8fa642f936.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83f6a1837a65df12b7c55d25ca28cc939c2a6328.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83f6a1837a65df12b7c55d25ca28cc939c2a6328.hip new file mode 100644 index 00000000000..b53e8a7fea8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83f6a1837a65df12b7c55d25ca28cc939c2a6328.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_843e7888cba5f463d19fcb71aaaab25dc3d2c09d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_843e7888cba5f463d19fcb71aaaab25dc3d2c09d.hip new file mode 100644 index 00000000000..4311527376a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_843e7888cba5f463d19fcb71aaaab25dc3d2c09d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8441910c34830ad2459fb85c2c14af02da718fdc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8441910c34830ad2459fb85c2c14af02da718fdc.hip new file mode 100644 index 00000000000..49df5646dcc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8441910c34830ad2459fb85c2c14af02da718fdc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8457ea5726149efb8778e6d90798b8e48288fc9a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8457ea5726149efb8778e6d90798b8e48288fc9a.hip new file mode 100644 index 00000000000..1f44e17f5f6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8457ea5726149efb8778e6d90798b8e48288fc9a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_847feaf237911478173377a501ee19ee325b012b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_847feaf237911478173377a501ee19ee325b012b.hip new file mode 100644 index 00000000000..10e961bddae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_847feaf237911478173377a501ee19ee325b012b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84cca7528c7d1bf49ba79625733ff0ae7522c096.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84cca7528c7d1bf49ba79625733ff0ae7522c096.hip new file mode 100644 index 00000000000..e68e0004da8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84cca7528c7d1bf49ba79625733ff0ae7522c096.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84dc4af43de08130a04bfa06df9799b6e9e96900.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84dc4af43de08130a04bfa06df9799b6e9e96900.hip new file mode 100644 index 00000000000..2631aea8563 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84dc4af43de08130a04bfa06df9799b6e9e96900.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84e8ae99e184013739019c93d07caddce532382b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84e8ae99e184013739019c93d07caddce532382b.hip new file mode 100644 index 00000000000..81a62ac01ad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84e8ae99e184013739019c93d07caddce532382b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84fc5e94f89d6a9287cf64662a372784511468dd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84fc5e94f89d6a9287cf64662a372784511468dd.hip new file mode 100644 index 00000000000..f9b3c603639 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84fc5e94f89d6a9287cf64662a372784511468dd.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else 
+ return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8513d96a66a4d9fb8dfc84afba7e1d8c200248a6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8513d96a66a4d9fb8dfc84afba7e1d8c200248a6.hip new file mode 100644 index 00000000000..f626aec3121 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8513d96a66a4d9fb8dfc84afba7e1d8c200248a6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85156f2c556c6ef6180608c361b7b35ede71ffea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85156f2c556c6ef6180608c361b7b35ede71ffea.hip new file mode 100644 index 00000000000..747b26b3171 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85156f2c556c6ef6180608c361b7b35ede71ffea.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_854c8003a508ed3f8cbe6967c4ae2635a491c721.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_854c8003a508ed3f8cbe6967c4ae2635a491c721.hip new file mode 100644 index 00000000000..56d7d4ab9bf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_854c8003a508ed3f8cbe6967c4ae2635a491c721.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85908fe6dc9c629c82d6953081b10021e64583b1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85908fe6dc9c629c82d6953081b10021e64583b1.hip new file mode 100644 index 00000000000..e3c6e640275 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85908fe6dc9c629c82d6953081b10021e64583b1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85960fe542635079de5eca3c7785890cd4740005.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85960fe542635079de5eca3c7785890cd4740005.hip new file mode 100644 index 00000000000..b0136c22260 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85960fe542635079de5eca3c7785890cd4740005.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85fdde4b25e2fc8cbdd46c2850c19eac8d9af8f6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85fdde4b25e2fc8cbdd46c2850c19eac8d9af8f6.hip new file mode 100644 index 00000000000..4d88eeb5a8f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85fdde4b25e2fc8cbdd46c2850c19eac8d9af8f6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86309c036d96367939ccc3e8922595ac35a3e179.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86309c036d96367939ccc3e8922595ac35a3e179.hip new file mode 100644 index 00000000000..d500d2e7da3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86309c036d96367939ccc3e8922595ac35a3e179.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86513d6e065a44bcb0c789eed1e7e5456e800ab6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86513d6e065a44bcb0c789eed1e7e5456e800ab6.hip new file mode 100644 index 00000000000..7af5f97ebab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86513d6e065a44bcb0c789eed1e7e5456e800ab6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_865eb90b1a2d64acc0f6fbe1d807c501fd4be3cd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_865eb90b1a2d64acc0f6fbe1d807c501fd4be3cd.hip new file mode 100644 index 00000000000..20805e32fa9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_865eb90b1a2d64acc0f6fbe1d807c501fd4be3cd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8689126a7eb09d81baaf8f99dbff8932fbeab3cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8689126a7eb09d81baaf8f99dbff8932fbeab3cb.hip new file mode 100644 index 00000000000..4a32d7b4c69 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8689126a7eb09d81baaf8f99dbff8932fbeab3cb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86d73393d0d8b769f30222f7817563a955c36dfc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86d73393d0d8b769f30222f7817563a955c36dfc.hip new file mode 100644 index 00000000000..ee406c5d6d8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86d73393d0d8b769f30222f7817563a955c36dfc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86fa51b8c7a2f3fac5cf4cd2951ed2ede5c35450.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86fa51b8c7a2f3fac5cf4cd2951ed2ede5c35450.hip new file mode 100644 index 00000000000..61e7e816ce7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86fa51b8c7a2f3fac5cf4cd2951ed2ede5c35450.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_875b08ca602fe48840c72cd61798acb98540fcd6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_875b08ca602fe48840c72cd61798acb98540fcd6.hip new file mode 100644 index 00000000000..b6c5e9518a1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_875b08ca602fe48840c72cd61798acb98540fcd6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_876a418fbe6183d0392b7a7d9986d067e323e2b9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_876a418fbe6183d0392b7a7d9986d067e323e2b9.hip new file mode 100644 index 00000000000..90aee59be11 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_876a418fbe6183d0392b7a7d9986d067e323e2b9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_877e33463b3bf1853c6d2d2009af8d27bf88abbe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_877e33463b3bf1853c6d2d2009af8d27bf88abbe.hip new file mode 100644 index 00000000000..1f369ec2839 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_877e33463b3bf1853c6d2d2009af8d27bf88abbe.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8793dc3217e154b65ebba065aa10ab4dc2374ae8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8793dc3217e154b65ebba065aa10ab4dc2374ae8.hip new file mode 100644 index 00000000000..2cc6aa370a7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8793dc3217e154b65ebba065aa10ab4dc2374ae8.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_87e3a06266deda093bdf28af82d8666066157fc6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_87e3a06266deda093bdf28af82d8666066157fc6.hip new file mode 100644 index 00000000000..83887203cff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_87e3a06266deda093bdf28af82d8666066157fc6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8840e8899b4e632714632450bcef001c6070f955.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8840e8899b4e632714632450bcef001c6070f955.hip new file mode 100644 index 00000000000..28da9b996aa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8840e8899b4e632714632450bcef001c6070f955.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ac7f6cbdfca2e397bcb86af4216e87166601c7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ac7f6cbdfca2e397bcb86af4216e87166601c7.hip new file mode 100644 index 00000000000..e5f75f10395 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ac7f6cbdfca2e397bcb86af4216e87166601c7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88c04463f9c5ce565a9daa8c22e16de80fadd707.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88c04463f9c5ce565a9daa8c22e16de80fadd707.hip new file mode 100644 index 00000000000..2f0c3065dd2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88c04463f9c5ce565a9daa8c22e16de80fadd707.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88d52c5f70abb525b9c8aa8fc1cb3997c33ed67c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88d52c5f70abb525b9c8aa8fc1cb3997c33ed67c.hip new file mode 100644 index 00000000000..6b05a1501a7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88d52c5f70abb525b9c8aa8fc1cb3997c33ed67c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS 
AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 
0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ea5b5346c87cc4fc1e841c518080df4ab811a2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ea5b5346c87cc4fc1e841c518080df4ab811a2.hip new file mode 100644 index 00000000000..42776dda70c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ea5b5346c87cc4fc1e841c518080df4ab811a2.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ed7f650c958a644c8031aeb88688b1e42458e5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ed7f650c958a644c8031aeb88688b1e42458e5.hip new file mode 100644 index 00000000000..fe7b3a8ea30 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ed7f650c958a644c8031aeb88688b1e42458e5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_890aa875ac13957f00b30210477924697abf0c9e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_890aa875ac13957f00b30210477924697abf0c9e.hip new file mode 100644 index 00000000000..642ee740a65 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_890aa875ac13957f00b30210477924697abf0c9e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_89617bdea526d12d6a33ed42b9b0018c0b173722.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_89617bdea526d12d6a33ed42b9b0018c0b173722.hip new file mode 100644 index 00000000000..ebfb3e4ce54 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_89617bdea526d12d6a33ed42b9b0018c0b173722.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_89a3327da9a3411ff1cddc67eb647083cd947a92.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_89a3327da9a3411ff1cddc67eb647083cd947a92.hip new file mode 100644 index 00000000000..6c1dfc913dc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_89a3327da9a3411ff1cddc67eb647083cd947a92.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a1fd28acfe85b3adac859c4bbffa4d28fe634fe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a1fd28acfe85b3adac859c4bbffa4d28fe634fe.hip new file mode 100644 index 00000000000..a66c9b2f45a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a1fd28acfe85b3adac859c4bbffa4d28fe634fe.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a58d4bca33c4c0e79141a56688049237d170d1b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a58d4bca33c4c0e79141a56688049237d170d1b.hip new file mode 100644 index 00000000000..422fabd80a5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a58d4bca33c4c0e79141a56688049237d170d1b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a824621a50cdc3cbadc4b1f9ef18e1325385082.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a824621a50cdc3cbadc4b1f9ef18e1325385082.hip new file mode 100644 index 00000000000..8ca426d1bea --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a824621a50cdc3cbadc4b1f9ef18e1325385082.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a980749c6b2a18c80426dd189e5506334343ca4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a980749c6b2a18c80426dd189e5506334343ca4.hip new file mode 100644 index 00000000000..d9041d10891 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a980749c6b2a18c80426dd189e5506334343ca4.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8adbdcd28cb2f078f89adf9aad2b3d4a0a477823.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8adbdcd28cb2f078f89adf9aad2b3d4a0a477823.hip new file mode 100644 index 00000000000..96c9a760bbd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8adbdcd28cb2f078f89adf9aad2b3d4a0a477823.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b17c082f249649eca733a8f0cdf9a1205c3e3d7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b17c082f249649eca733a8f0cdf9a1205c3e3d7.hip new file mode 100644 index 00000000000..74b3ade1218 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b17c082f249649eca733a8f0cdf9a1205c3e3d7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b9043572cabb65435627a3faf23b18d039bbcd8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b9043572cabb65435627a3faf23b18d039bbcd8.hip new file mode 100644 index 00000000000..a3b036096ec --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b9043572cabb65435627a3faf23b18d039bbcd8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b92990df507e82f96eeb7aa3ec00c01437566fb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b92990df507e82f96eeb7aa3ec00c01437566fb.hip new file mode 100644 index 00000000000..b2bfb22e16e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b92990df507e82f96eeb7aa3ec00c01437566fb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8bd1a40b12ce927323594fcce61eb9c20cc5e3d4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8bd1a40b12ce927323594fcce61eb9c20cc5e3d4.hip new file mode 100644 index 00000000000..a9897db9228 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8bd1a40b12ce927323594fcce61eb9c20cc5e3d4.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8bd7b8c63a51c8639b3cf27ad09d41ae47c480d3.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8bd7b8c63a51c8639b3cf27ad09d41ae47c480d3.hip new file mode 100644 index 00000000000..cb223c4050b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8bd7b8c63a51c8639b3cf27ad09d41ae47c480d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> 
+void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c074afcf33e3f3534ac3577484237fcfd2ca48e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c074afcf33e3f3534ac3577484237fcfd2ca48e.hip new file mode 100644 index 00000000000..1f04d3a1af9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c074afcf33e3f3534ac3577484237fcfd2ca48e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c13c4f3f645a2bb475eb1c55ce1de452f0e2332.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c13c4f3f645a2bb475eb1c55ce1de452f0e2332.hip new file mode 100644 index 00000000000..ac9d48c4fb0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c13c4f3f645a2bb475eb1c55ce1de452f0e2332.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c3bd4e029bba76ebfc79e6522dbc8ca0bba5dd2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c3bd4e029bba76ebfc79e6522dbc8ca0bba5dd2.hip new file mode 100644 index 00000000000..7cd1eab7490 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c3bd4e029bba76ebfc79e6522dbc8ca0bba5dd2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c4688cbd23727dd0ea9a36fb977b31aeae98d65.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c4688cbd23727dd0ea9a36fb977b31aeae98d65.hip new file mode 100644 index 00000000000..b3292ab6d77 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c4688cbd23727dd0ea9a36fb977b31aeae98d65.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c7970957024de050748d3e31cef434f582d968b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c7970957024de050748d3e31cef434f582d968b.hip new file mode 100644 index 00000000000..45a342f057d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c7970957024de050748d3e31cef434f582d968b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8cdcdeb845e7bcdb89ef70ab2a97157d4db3cb52.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8cdcdeb845e7bcdb89ef70ab2a97157d4db3cb52.hip new file mode 100644 index 00000000000..a14bb66e2b5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8cdcdeb845e7bcdb89ef70ab2a97157d4db3cb52.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8cf1007430da272174d3476d042f398627e83512.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8cf1007430da272174d3476d042f398627e83512.hip new file mode 100644 index 00000000000..ad4263d4eb1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8cf1007430da272174d3476d042f398627e83512.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d079c1eb36db8461fa8b861c56760afcd97cc34.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d079c1eb36db8461fa8b861c56760afcd97cc34.hip new file mode 100644 index 00000000000..adef882533f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d079c1eb36db8461fa8b861c56760afcd97cc34.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d7549e66ef309e32779ddc2a1f14e79bae53754.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d7549e66ef309e32779ddc2a1f14e79bae53754.hip new file mode 100644 index 00000000000..d2e731ca555 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d7549e66ef309e32779ddc2a1f14e79bae53754.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d79fe8a600c3b4e0ec9aa510f8036ba2b608985.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d79fe8a600c3b4e0ec9aa510f8036ba2b608985.hip new file mode 100644 index 00000000000..9b994bab809 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d79fe8a600c3b4e0ec9aa510f8036ba2b608985.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8da8285bd6182355e3164cdc5a983375cdf0a61d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8da8285bd6182355e3164cdc5a983375cdf0a61d.hip new file mode 100644 index 00000000000..61f063c3839 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8da8285bd6182355e3164cdc5a983375cdf0a61d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
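Each of these autogenerated backward translation units follows the same shape: it pins one combination of dtype, block/warp tile, pipeline, bias, dropout and padding flags into a local trait alias (dq_dk_dv_trait_0) and then appears to provide the explicit specializations of fmha_bwd_dq_dk_dv_, fmha_bwd_dq_dk_dv_oneshot_ and fmha_bwd_dq_dk_dv_get_name_ for exactly that combination, so a generic dispatcher can link against only the instances that were generated. A minimal, self-contained sketch of this trait-keyed specialization pattern; the trait tags and kernel_name_ below are illustrative placeholders, not the CK API:

#include <iostream>
#include <string>

// One empty "trait" tag per generated combination (dtype / head dim / pipeline / ...).
struct trait_bf16_hdim128 {};
struct trait_fp16_hdim64 {};

// Generic entry point: declared once in a shared header, defined only through
// explicit specializations, one per autogenerated translation unit.
template <typename Trait>
std::string kernel_name_();

template <>
std::string kernel_name_<trait_bf16_hdim128>() { return "fmha_bwd_bf16_hdim128"; }

template <>
std::string kernel_name_<trait_fp16_hdim64>() { return "fmha_bwd_fp16_hdim64"; }

int main() {
    // A dispatcher selects the trait and therefore the generated instance.
    std::cout << kernel_name_<trait_bf16_hdim128>() << "\n"
              << kernel_name_<trait_fp16_hdim64>() << "\n";
    return 0;
}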
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e1b48a28b71c7f4c78eb14321b39951a7c5e903.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e1b48a28b71c7f4c78eb14321b39951a7c5e903.hip new file mode 100644 index 00000000000..de24e5557f5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e1b48a28b71c7f4c78eb14321b39951a7c5e903.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e2c587db8bd9f1b551624e0cf8b67a90245d7da.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e2c587db8bd9f1b551624e0cf8b67a90245d7da.hip new file mode 100644 index 00000000000..e94aa1fca6a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e2c587db8bd9f1b551624e0cf8b67a90245d7da.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e2d5f979fc4fbd0991581a020a414f9c8656ae2.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e2d5f979fc4fbd0991581a020a414f9c8656ae2.hip new file mode 100644 index 00000000000..682bef8a507 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e2d5f979fc4fbd0991581a020a414f9c8656ae2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> 
+void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e431313fe082958d31b68d2fd0d61df0fe56736.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e431313fe082958d31b68d2fd0d61df0fe56736.hip new file mode 100644 index 00000000000..25b797b2348 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e431313fe082958d31b68d2fd0d61df0fe56736.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e50ea8dd480012cbe10be392cd26d1870e6ef9b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e50ea8dd480012cbe10be392cd26d1870e6ef9b.hip new file mode 100644 index 00000000000..ed783900cfd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e50ea8dd480012cbe10be392cd26d1870e6ef9b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
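Every launcher in these files wraps the actual kernel launch in the same compile-time architecture guard: the device code is emitted only when the HIP device pass targets gfx90a or gfx942, and any other target (including the host pass) gets a stub that returns 0.0. A reduced sketch of the guard, with the real ck_tile launch replaced by a placeholder and launch_or_stub used as an illustrative name:

#include <cstdio>

// The __gfx90a__/__gfx942__ macros are predefined by the HIP compiler only
// while compiling device code for those architectures, so every other target
// (and the host pass) compiles the #else branch.
float launch_or_stub() {
#if (defined(__gfx90a__) || defined(__gfx942__))
    // ... the generated files call ck_tile::launch_kernel(...) here and return
    // its result (presumably the measured kernel time) ...
    return 1.0f; // placeholder
#else
    return 0.0f; // unsupported arch: nothing is launched
#endif
}

int main() {
    std::printf("launch_or_stub() = %f\n", launch_or_stub());
    return 0;
}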
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e675919a6c7758cbbeecb83b7ac6c62f95cdb46.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e675919a6c7758cbbeecb83b7ac6c62f95cdb46.hip new file mode 100644 index 00000000000..2dc0c8d0afe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e675919a6c7758cbbeecb83b7ac6c62f95cdb46.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
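Alongside the main fmha_bwd_dq_dk_dv_ kernels, each backward configuration in this diff also gets a fmha_bwd_dot_do_o_ pre-pass (this file and others like it) and a fmha_bwd_convert_dq_ post-pass that converts the fp32 dQ accumulator back to the output dtype. The dot_do_o pre-pass appears to compute the per-row dot product of the forward output O and its gradient dO that the FlashAttention backward recurrence uses. A plain CPU reference for that quantity, ignoring the real kernels' tiling and layouts; the dot_do_o helper below is a hypothetical reference, not the CK kernel:

#include <cstdio>
#include <vector>

// For each query row i: D[i] = sum_j dO[i][j] * O[i][j].
std::vector<float> dot_do_o(const std::vector<float>& dO,
                            const std::vector<float>& O,
                            int rows, int cols) {
    std::vector<float> D(rows, 0.0f);
    for (int i = 0; i < rows; ++i)
        for (int j = 0; j < cols; ++j)
            D[i] += dO[i * cols + j] * O[i * cols + j];
    return D;
}

int main() {
    // 2 x 2 toy example: D = {1*5 + 2*6, 3*7 + 4*8} = {17, 53}.
    std::vector<float> dO = {1, 2, 3, 4}, O = {5, 6, 7, 8};
    auto D = dot_do_o(dO, O, 2, 2);
    std::printf("%g %g\n", D[0], D[1]);
    return 0;
}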
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e812705ae3e452810794fa7caceef2ef6066dfb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e812705ae3e452810794fa7caceef2ef6066dfb.hip new file mode 100644 index 00000000000..74ef40ca530 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e812705ae3e452810794fa7caceef2ef6066dfb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e816fcad5e9ecfca94a6491eb2274bcc41e558b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e816fcad5e9ecfca94a6491eb2274bcc41e558b.hip new file mode 100644 index 00000000000..ef246224837 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e816fcad5e9ecfca94a6491eb2274bcc41e558b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e938d0e3ad30db201880642e57758285b2ec4cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e938d0e3ad30db201880642e57758285b2ec4cb.hip new file mode 100644 index 00000000000..16488c9a130 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e938d0e3ad30db201880642e57758285b2ec4cb.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8efb5fc2ace6839eac741c5e6616665845f43566.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8efb5fc2ace6839eac741c5e6616665845f43566.hip new file mode 100644 
index 00000000000..af93b92d570 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8efb5fc2ace6839eac741c5e6616665845f43566.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto 
[kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f607ee20c0d92b6dbd0338f139517fdcce98d0c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f607ee20c0d92b6dbd0338f139517fdcce98d0c.hip new file mode 100644 index 00000000000..03cbd920191 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f607ee20c0d92b6dbd0338f139517fdcce98d0c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + 
+#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f6e463eedd3e65b9c79feed3cd92ad8cbc9f036.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f6e463eedd3e65b9c79feed3cd92ad8cbc9f036.hip new file mode 100644 index 00000000000..b10c84772f5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f6e463eedd3e65b9c79feed3cd92ad8cbc9f036.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f7166d4bb0c1c9b9999ba16a1adbf09ebfdb6f1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f7166d4bb0c1c9b9999ba16a1adbf09ebfdb6f1.hip new file mode 100644 index 00000000000..ee228469413 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f7166d4bb0c1c9b9999ba16a1adbf09ebfdb6f1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fa4c40e244b412a07933d369704bcdaa6d5e74c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fa4c40e244b412a07933d369704bcdaa6d5e74c.hip new file mode 100644 index 00000000000..7684341efb1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fa4c40e244b412a07933d369704bcdaa6d5e74c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fb224b40a7be7db0a9c5c08cc5ab05b526c14e8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fb224b40a7be7db0a9c5c08cc5ab05b526c14e8.hip new file mode 100644 index 00000000000..095b4f99e34 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fb224b40a7be7db0a9c5c08cc5ab05b526c14e8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fb33fc20f2e85e915f1b1529ae87981dfcaf86d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fb33fc20f2e85e915f1b1529ae87981dfcaf86d.hip new file mode 100644 index 00000000000..f866e39e0f0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fb33fc20f2e85e915f1b1529ae87981dfcaf86d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fc08b4f3959a2375ac03f40c4ce12d70cdc2d80.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fc08b4f3959a2375ac03f40c4ce12d70cdc2d80.hip new file mode 100644 index 00000000000..1f7c65978ef --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fc08b4f3959a2375ac03f40c4ce12d70cdc2d80.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9009b7d39346537aa6c4a4e46b81139f603edb60.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9009b7d39346537aa6c4a4e46b81139f603edb60.hip new file mode 100644 index 00000000000..27a174e9f93 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9009b7d39346537aa6c4a4e46b81139f603edb60.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_900d7f81c73b35ea64095d01c5d48d9190839e0a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_900d7f81c73b35ea64095d01c5d48d9190839e0a.hip new file mode 100644 index 00000000000..62cad16726a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_900d7f81c73b35ea64095d01c5d48d9190839e0a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9068ba8df8b0e977e9769f6acf6cfee6b00b9922.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9068ba8df8b0e977e9769f6acf6cfee6b00b9922.hip new file mode 100644 index 00000000000..3c572aee9c4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9068ba8df8b0e977e9769f6acf6cfee6b00b9922.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_906fa8bf5e992ddc25815486ae9c24d8bfba7227.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_906fa8bf5e992ddc25815486ae9c24d8bfba7227.hip new file mode 100644 index 00000000000..22c7abd5ac6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_906fa8bf5e992ddc25815486ae9c24d8bfba7227.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90b17d8cba28cceddb3ef907df878aeef0762d15.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90b17d8cba28cceddb3ef907df878aeef0762d15.hip new file mode 100644 index 00000000000..55938f955b0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90b17d8cba28cceddb3ef907df878aeef0762d15.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90da0d469cca5c8481504148468460c85a15c559.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90da0d469cca5c8481504148468460c85a15c559.hip new file mode 100644 index 00000000000..cbf40ea7340 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90da0d469cca5c8481504148468460c85a15c559.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
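Note on the pattern these generated files share: each forward file explicitly specializes the launcher for the trait_0 bundle it defines, while each backward file specializes three entry points (the timed fmha_bwd_dq_dk_dv_, an untimed _oneshot_ variant, and _get_name_ for logging), and a runtime dispatcher elsewhere selects the specialization whose traits match the call. The following is a minimal standalone sketch of that explicit-specialization dispatch pattern only; every name in it (launch_args, run_kernel_, the trait structs) is hypothetical and is not part of ck_tile or of these generated headers.

// Sketch: one primary template declared centrally, one explicit specialization
// per "generated" trait bundle, and a dispatcher mapping runtime parameters to
// the matching specialization. Hypothetical names throughout.
#include <iostream>

struct launch_args { int hdim; bool is_fp16; };   // stand-in for fmha_*_args

// Primary template, declared once (analogous to the fmha launcher templates).
template <typename Trait>
float run_kernel_(const launch_args& a);

// One trait bundle per generated translation unit (values are illustrative).
struct trait_hdim64_fp16  { static constexpr int hdim = 64;  };
struct trait_hdim128_fp16 { static constexpr int hdim = 128; };

// Each "generated file" contributes one explicit specialization.
template <> float run_kernel_<trait_hdim64_fp16>(const launch_args&)  { return 0.1f; }
template <> float run_kernel_<trait_hdim128_fp16>(const launch_args&) { return 0.2f; }

// Dispatcher: picks the specialization matching the runtime problem shape.
float run_kernel(const launch_args& a) {
    if (a.hdim <= 64) return run_kernel_<trait_hdim64_fp16>(a);
    return run_kernel_<trait_hdim128_fp16>(a);
}

int main() {
    std::cout << run_kernel({128, true}) << "\n";  // selects the hdim-128 bundle
    return 0;
}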
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90e5c56e92712d00092ba102a5eb5176a3e5d471.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90e5c56e92712d00092ba102a5eb5176a3e5d471.hip new file mode 100644 index 00000000000..a51b705cf20 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90e5c56e92712d00092ba102a5eb5176a3e5d471.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_910cb8bd09d287a1566265eb1e8894fe68d3cc81.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_910cb8bd09d287a1566265eb1e8894fe68d3cc81.hip new file mode 100644 index 00000000000..1d0fb66b1ee --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_910cb8bd09d287a1566265eb1e8894fe68d3cc81.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_915b75db795dbef037b14b003ee073665fe35d3e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_915b75db795dbef037b14b003ee073665fe35d3e.hip new file mode 100644 index 00000000000..a010fddc490 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_915b75db795dbef037b14b003ee073665fe35d3e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9163ae070075f26926a86d39e15c27e6edb1f1cf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9163ae070075f26926a86d39e15c27e6edb1f1cf.hip new file mode 100644 index 00000000000..6c74184d0c3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9163ae070075f26926a86d39e15c27e6edb1f1cf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91695dea4171747fb3cc6d910459f800608d07c1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91695dea4171747fb3cc6d910459f800608d07c1.hip new file mode 100644 index 00000000000..eda2b272b9e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91695dea4171747fb3cc6d910459f800608d07c1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_919ae177b7a793fa352c4f6bb8e4175f3064d814.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_919ae177b7a793fa352c4f6bb8e4175f3064d814.hip new file mode 100644 index 00000000000..475bca57c36 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_919ae177b7a793fa352c4f6bb8e4175f3064d814.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91a6200e36944b1f11106c02f7fcee053f01ee71.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91a6200e36944b1f11106c02f7fcee053f01ee71.hip new file mode 100644 index 00000000000..f116c775fa8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91a6200e36944b1f11106c02f7fcee053f01ee71.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91b9e2616c2fe0480096b1ccf0f74d584b220146.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91b9e2616c2fe0480096b1ccf0f74d584b220146.hip new file mode 100644 index 00000000000..972b36fe16a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91b9e2616c2fe0480096b1ccf0f74d584b220146.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91c916e14198f6d18dc89915e379b01070434e91.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91c916e14198f6d18dc89915e379b01070434e91.hip new file mode 100644 index 00000000000..eef8d609f64 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91c916e14198f6d18dc89915e379b01070434e91.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void 
fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9207a63fc55c411c73e4f93306c5ffed800dd249.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9207a63fc55c411c73e4f93306c5ffed800dd249.hip new file mode 100644 index 00000000000..c5675d69691 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9207a63fc55c411c73e4f93306c5ffed800dd249.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92121fd448b4640a17e1a7fe73bb7b58714c0afb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92121fd448b4640a17e1a7fe73bb7b58714c0afb.hip new file mode 100644 index 00000000000..7a41e134c84 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92121fd448b4640a17e1a7fe73bb7b58714c0afb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_921f789d619db6f225e8e9d646e93bbc9dc1a669.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_921f789d619db6f225e8e9d646e93bbc9dc1a669.hip new file mode 100644 index 00000000000..ba1011e9d6c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_921f789d619db6f225e8e9d646e93bbc9dc1a669.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92739f4464512feee083b875e11e11eee4f5b448.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92739f4464512feee083b875e11e11eee4f5b448.hip new file mode 100644 index 00000000000..5f9078186d2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92739f4464512feee083b875e11e11eee4f5b448.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92992be6252f2afdc368bd4baec4b8a55ae0abf8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92992be6252f2afdc368bd4baec4b8a55ae0abf8.hip new file mode 100644 index 00000000000..7ab5d7d1705 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92992be6252f2afdc368bd4baec4b8a55ae0abf8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92b0770fe64e3c60b9e56170aa88bbf74802a813.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92b0770fe64e3c60b9e56170aa88bbf74802a813.hip new file mode 100644 index 00000000000..ced9c424f09 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92b0770fe64e3c60b9e56170aa88bbf74802a813.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92b722cdabcfaa388ccc6ccceb7e42462f3bdcd1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92b722cdabcfaa388ccc6ccceb7e42462f3bdcd1.hip new file mode 100644 index 00000000000..7b63fe2ccfb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92b722cdabcfaa388ccc6ccceb7e42462f3bdcd1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92ba64cdf615c1be2865f027a293cb530fc07dc6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92ba64cdf615c1be2865f027a293cb530fc07dc6.hip new file mode 100644 index 00000000000..4440b7e6504 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92ba64cdf615c1be2865f027a293cb530fc07dc6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92d841e6d783bb46d841aafd9027f92dd1b61b88.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92d841e6d783bb46d841aafd9027f92dd1b61b88.hip new file mode 100644 index 00000000000..1e13fa7afc0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92d841e6d783bb46d841aafd9027f92dd1b61b88.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92e53359c69bbe4d7405d45261a8a62008eb7d06.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92e53359c69bbe4d7405d45261a8a62008eb7d06.hip new file mode 100644 index 00000000000..d2ae7ae45fb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92e53359c69bbe4d7405d45261a8a62008eb7d06.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92f9ad0fb65638cfffb3e7786f2cbf01d9585b23.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92f9ad0fb65638cfffb3e7786f2cbf01d9585b23.hip new file mode 100644 index 00000000000..908e7ff0b5b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92f9ad0fb65638cfffb3e7786f2cbf01d9585b23.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93054acb8a9508fd0f0f486367fb62454de47c39.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93054acb8a9508fd0f0f486367fb62454de47c39.hip new file mode 100644 index 00000000000..4501cb318fa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93054acb8a9508fd0f0f486367fb62454de47c39.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_931cf8d05cfa45319f4e5bb49334d35a530bffcf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_931cf8d05cfa45319f4e5bb49334d35a530bffcf.hip new file mode 100644 index 00000000000..6da3c2e6000 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_931cf8d05cfa45319f4e5bb49334d35a530bffcf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93728d999ae43ee1b5a16e60b90cf8533c7d303f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93728d999ae43ee1b5a16e60b90cf8533c7d303f.hip new file mode 100644 index 00000000000..0e4e84ffe0b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93728d999ae43ee1b5a16e60b90cf8533c7d303f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_937801fbb43fb6797f0425f08d13926b74d87c4a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_937801fbb43fb6797f0425f08d13926b74d87c4a.hip new file mode 100644 index 00000000000..1c77f540b68 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_937801fbb43fb6797f0425f08d13926b74d87c4a.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_937c48d0b7096ad6c8bc445f13f2c8c1934695ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_937c48d0b7096ad6c8bc445f13f2c8c1934695ab.hip new file mode 100644 index 00000000000..814ad8ec592 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_937c48d0b7096ad6c8bc445f13f2c8c1934695ab.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93b885d6869400b0dc2ef1b2c2636ddfd21cde31.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93b885d6869400b0dc2ef1b2c2636ddfd21cde31.hip new file mode 100644 index 00000000000..923593adb17 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93b885d6869400b0dc2ef1b2c2636ddfd21cde31.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_942439e4f5644a3a4630481bc7d98834b29b6e1c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_942439e4f5644a3a4630481bc7d98834b29b6e1c.hip new file mode 100644 index 00000000000..dd4fd854092 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_942439e4f5644a3a4630481bc7d98834b29b6e1c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94a94d145e575747c8956ac703810582c819e2e8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94a94d145e575747c8956ac703810582c819e2e8.hip new file mode 100644 index 00000000000..734603126ba --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94a94d145e575747c8956ac703810582c819e2e8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94aa519eb57e5797125728492d9330f5c0f0670a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94aa519eb57e5797125728492d9330f5c0f0670a.hip new file mode 100644 index 00000000000..873f1b44ded --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94aa519eb57e5797125728492d9330f5c0f0670a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94f6f9dee9f0c3825d91f4d320a5280070e60ee7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94f6f9dee9f0c3825d91f4d320a5280070e60ee7.hip new file mode 100644 index 00000000000..026413f14c7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94f6f9dee9f0c3825d91f4d320a5280070e60ee7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_95061acc6650fc7b79fa1fe5b2b1e083555eec2c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_95061acc6650fc7b79fa1fe5b2b1e083555eec2c.hip new file mode 100644 index 00000000000..556c9bd087b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_95061acc6650fc7b79fa1fe5b2b1e083555eec2c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_951343832a5bfd060c8d12da0d8a090f070a717d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_951343832a5bfd060c8d12da0d8a090f070a717d.hip new file mode 100644 index 00000000000..759e5ee3190 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_951343832a5bfd060c8d12da0d8a090f070a717d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9545f95c1093c60f0fb6c794636f79aaeb53b733.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9545f95c1093c60f0fb6c794636f79aaeb53b733.hip new file mode 100644 index 00000000000..9038cd3adf4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9545f95c1093c60f0fb6c794636f79aaeb53b733.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_95530399ad7b43d8ce2c89da24c71056f2146b18.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_95530399ad7b43d8ce2c89da24c71056f2146b18.hip new file mode 100644 index 00000000000..dbc2e060f86 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_95530399ad7b43d8ce2c89da24c71056f2146b18.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9583148fd684a7e6a312127e023798278415bd27.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9583148fd684a7e6a312127e023798278415bd27.hip new file mode 100644 index 00000000000..dcc79ee6a0b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9583148fd684a7e6a312127e023798278415bd27.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9594816877815bc0294610ca24f986fdccdc7c6f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9594816877815bc0294610ca24f986fdccdc7c6f.hip new file mode 100644 index 00000000000..ca040732924 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9594816877815bc0294610ca24f986fdccdc7c6f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_960ecb3013071fb65f2d5ed4c947c4bf303e5308.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_960ecb3013071fb65f2d5ed4c947c4bf303e5308.hip new file mode 100644 index 00000000000..786928b952c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_960ecb3013071fb65f2d5ed4c947c4bf303e5308.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9638c9618dbf2af119e37596f7eb0fd3f8d72748.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9638c9618dbf2af119e37596f7eb0fd3f8d72748.hip new file mode 100644 index 00000000000..4f7978f8ab6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9638c9618dbf2af119e37596f7eb0fd3f8d72748.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_963986150adcd6e1d3886bacf2166de1252e14df.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_963986150adcd6e1d3886bacf2166de1252e14df.hip new file mode 100644 index 00000000000..cc7bb4ca2c1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_963986150adcd6e1d3886bacf2166de1252e14df.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_964f916d3484295b5918e2e4c22c5529588a5662.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_964f916d3484295b5918e2e4c22c5529588a5662.hip new file mode 100644 index 00000000000..2021cb3b948 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_964f916d3484295b5918e2e4c22c5529588a5662.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template 
<> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9689ecd7bf51bcffe9f5002959bdda41c50a3c8b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9689ecd7bf51bcffe9f5002959bdda41c50a3c8b.hip new file mode 100644 index 00000000000..ef65aa0cbe9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9689ecd7bf51bcffe9f5002959bdda41c50a3c8b.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_968fc75a7d102aca068e3ceb6111728c280fa837.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_968fc75a7d102aca068e3ceb6111728c280fa837.hip new file mode 100644 index 00000000000..50a0355bf06 --- /dev/null +++ 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_968fc75a7d102aca068e3ceb6111728c280fa837.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96c129dd4c798343d6f78ab78056f0faf2f1c9d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96c129dd4c798343d6f78ab78056f0faf2f1c9d3.hip new file mode 100644 index 00000000000..8594f317343 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96c129dd4c798343d6f78ab78056f0faf2f1c9d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + 
+template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96c5e79f54b71677124f555b0ae4bfd27248d099.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96c5e79f54b71677124f555b0ae4bfd27248d099.hip new file mode 100644 index 00000000000..21bb0448971 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96c5e79f54b71677124f555b0ae4bfd27248d099.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96caa2056d99eb67ada498e287b4fae984397691.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96caa2056d99eb67ada498e287b4fae984397691.hip new file mode 100644 index 00000000000..cb1ec3c5f7e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96caa2056d99eb67ada498e287b4fae984397691.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96dee49ec6755006d67f0c30c65f50558bba69b0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96dee49ec6755006d67f0c30c65f50558bba69b0.hip new file mode 100644 index 00000000000..1b1bf13c290 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96dee49ec6755006d67f0c30c65f50558bba69b0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96f1bb85dff8c97846f6b2e8796a6289bcd0d9d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96f1bb85dff8c97846f6b2e8796a6289bcd0d9d3.hip new file mode 100644 index 00000000000..d32f58de2c3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96f1bb85dff8c97846f6b2e8796a6289bcd0d9d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_970073c70133ff2ee4737f803a0ac43801c47242.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_970073c70133ff2ee4737f803a0ac43801c47242.hip new file mode 100644 index 00000000000..ce373394af3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_970073c70133ff2ee4737f803a0ac43801c47242.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_971a08c2e48d805b295d979b24173a04cf58def0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_971a08c2e48d805b295d979b24173a04cf58def0.hip new file mode 100644 index 00000000000..1bb18015aa1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_971a08c2e48d805b295d979b24173a04cf58def0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_97246460c21bc66c0f13936d27477a9fca1c44d1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_97246460c21bc66c0f13936d27477a9fca1c44d1.hip new file mode 100644 index 00000000000..ab8f2aeef41 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_97246460c21bc66c0f13936d27477a9fca1c44d1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9745b04a8026a01828c5dd606d89d044d3ed1d99.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9745b04a8026a01828c5dd606d89d044d3ed1d99.hip new file mode 100644 index 00000000000..2c75f779939 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9745b04a8026a01828c5dd606d89d044d3ed1d99.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_976cf509d9c2bf86ba6ee5ded544fa8e6717f590.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_976cf509d9c2bf86ba6ee5ded544fa8e6717f590.hip new file mode 100644 index 00000000000..063c77a4549 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_976cf509d9c2bf86ba6ee5ded544fa8e6717f590.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_977137b371df841993c8d0584be7d83aca6add78.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_977137b371df841993c8d0584be7d83aca6add78.hip new file mode 100644 index 00000000000..152d98ac85d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_977137b371df841993c8d0584be7d83aca6add78.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_97851d5ecbf02f8af623988b1a39c0b91e51533a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_97851d5ecbf02f8af623988b1a39c0b91e51533a.hip new file mode 100644 index 00000000000..63e59aa9da0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_97851d5ecbf02f8af623988b1a39c0b91e51533a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9801b25e0f132d647934deb395b62a3f70cc7c88.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9801b25e0f132d647934deb395b62a3f70cc7c88.hip new file mode 100644 index 00000000000..8c622562906 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9801b25e0f132d647934deb395b62a3f70cc7c88.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_987a617fae00fa90a1ba60937b0312c81087c19e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_987a617fae00fa90a1ba60937b0312c81087c19e.hip new file mode 100644 index 00000000000..78a4ab4e6e4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_987a617fae00fa90a1ba60937b0312c81087c19e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_987f00dd759d9714693e7517dfaa8bb427294d42.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_987f00dd759d9714693e7517dfaa8bb427294d42.hip new file mode 100644 index 00000000000..4fadf647940 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_987f00dd759d9714693e7517dfaa8bb427294d42.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9893336a4b00b2a63f23ed7e13ec54c82d9e5063.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9893336a4b00b2a63f23ed7e13ec54c82d9e5063.hip new file mode 100644 index 00000000000..43bcd219961 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9893336a4b00b2a63f23ed7e13ec54c82d9e5063.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + true, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98e484adeddf3394d8d7693b808d83b64c71ee69.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98e484adeddf3394d8d7693b808d83b64c71ee69.hip new file mode 100644 index 00000000000..704c87b0a9b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98e484adeddf3394d8d7693b808d83b64c71ee69.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + 
+template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98f5efcd500ce6b9ffc14bc9877e0ba457539925.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98f5efcd500ce6b9ffc14bc9877e0ba457539925.hip new file mode 100644 index 00000000000..3ecabacb7f2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98f5efcd500ce6b9ffc14bc9877e0ba457539925.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98f9a4f4d85f292b78123599a2e1798f12aa545b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98f9a4f4d85f292b78123599a2e1798f12aa545b.hip new file mode 100644 index 00000000000..65076c99dd4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98f9a4f4d85f292b78123599a2e1798f12aa545b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9990e6ad243a48b84304b5cad0c663c0802aedfd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9990e6ad243a48b84304b5cad0c663c0802aedfd.hip new file mode 100644 index 00000000000..bb635d6ff82 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9990e6ad243a48b84304b5cad0c663c0802aedfd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; 
+ return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99ae680eed89ea93a3a94586bd5a68dbc5439f37.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99ae680eed89ea93a3a94586bd5a68dbc5439f37.hip new file mode 100644 index 00000000000..f2c7385f46d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99ae680eed89ea93a3a94586bd5a68dbc5439f37.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99e2f290b962f1617b0a9d4fd6d55c43e4439d6f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99e2f290b962f1617b0a9d4fd6d55c43e4439d6f.hip new file mode 100644 index 00000000000..1d331ff1972 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99e2f290b962f1617b0a9d4fd6d55c43e4439d6f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99f8352674bd6bbe98944a1c0a769a4fc028a623.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99f8352674bd6bbe98944a1c0a769a4fc028a623.hip new file mode 100644 index 00000000000..f101ae9fe34 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99f8352674bd6bbe98944a1c0a769a4fc028a623.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a0a70932bd587759df1e5e150b25b0126d7b529.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a0a70932bd587759df1e5e150b25b0126d7b529.hip new file mode 100644 index 00000000000..a8b9243487d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a0a70932bd587759df1e5e150b25b0126d7b529.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = 
k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a20fa19d8d30654602e363806f559113218d66d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a20fa19d8d30654602e363806f559113218d66d.hip new file mode 100644 index 00000000000..53302ce3d97 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a20fa19d8d30654602e363806f559113218d66d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << 
std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a8e04fe9432a60f86ff0369e8c1851821074a04.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a8e04fe9432a60f86ff0369e8c1851821074a04.hip new file mode 100644 index 00000000000..b3d0ad587cf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a8e04fe9432a60f86ff0369e8c1851821074a04.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + 
+using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a9edbe35a8fac7796f00bde836bd547044770ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a9edbe35a8fac7796f00bde836bd547044770ea.hip new file mode 100644 index 00000000000..79ab1cd43cf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a9edbe35a8fac7796f00bde836bd547044770ea.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ab73ea77ec20ea3bfaf995dacf93a6960ecdca0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ab73ea77ec20ea3bfaf995dacf93a6960ecdca0.hip new file mode 100644 index 00000000000..ca4f0e8d462 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ab73ea77ec20ea3bfaf995dacf93a6960ecdca0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ad1f99284aafc8d7908d062f179a056eb314925.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ad1f99284aafc8d7908d062f179a056eb314925.hip new file mode 100644 index 00000000000..57ddba22487 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ad1f99284aafc8d7908d062f179a056eb314925.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ae866c7db36286876818bfb718ac35204fa3843.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ae866c7db36286876818bfb718ac35204fa3843.hip new file mode 100644 index 00000000000..e177ef281b7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ae866c7db36286876818bfb718ac35204fa3843.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9afe4b6f3b901ff4af81bd4f1cd8ff19f09d0b07.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9afe4b6f3b901ff4af81bd4f1cd8ff19f09d0b07.hip new file mode 100644 index 00000000000..e2376679103 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9afe4b6f3b901ff4af81bd4f1cd8ff19f09d0b07.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b062dd633645772e4f2caffd111af73184f7657.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b062dd633645772e4f2caffd111af73184f7657.hip new file mode 100644 index 00000000000..8fe8fbf9b2b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b062dd633645772e4f2caffd111af73184f7657.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b327f0fa1155f2235d76be45cd22e3db5a69429.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b327f0fa1155f2235d76be45cd22e3db5a69429.hip new file mode 100644 index 00000000000..63a0c899b89 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b327f0fa1155f2235d76be45cd22e3db5a69429.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b4dcde1ae3446b825dea739d4295c1d1ec5c4be.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b4dcde1ae3446b825dea739d4295c1d1ec5c4be.hip new file mode 100644 index 00000000000..263054dca17 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b4dcde1ae3446b825dea739d4295c1d1ec5c4be.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b6d08e63b9a90f2524cbfa8c5fcf8b82a1d2d36.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b6d08e63b9a90f2524cbfa8c5fcf8b82a1d2d36.hip new file mode 100644 index 00000000000..3c95ad357d9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b6d08e63b9a90f2524cbfa8c5fcf8b82a1d2d36.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b73c92a13757877f34bd8a13c6fb29b60999020.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b73c92a13757877f34bd8a13c6fb29b60999020.hip new file mode 100644 index 00000000000..c12ba4da6d4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b73c92a13757877f34bd8a13c6fb29b60999020.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b841b7cf5da31f0c30ec42c91cc8d5bd3fedd03.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b841b7cf5da31f0c30ec42c91cc8d5bd3fedd03.hip new file mode 100644 index 00000000000..82d78dccfa0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b841b7cf5da31f0c30ec42c91cc8d5bd3fedd03.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9bcc791049e3ff9ebc1a9085d2d20efcc2f99b71.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9bcc791049e3ff9ebc1a9085d2d20efcc2f99b71.hip new file mode 100644 index 00000000000..a3cda689646 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9bcc791049e3ff9ebc1a9085d2d20efcc2f99b71.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9bf235679af1ca03a6e601b4cf6cd0416d1c9091.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9bf235679af1ca03a6e601b4cf6cd0416d1c9091.hip new file mode 100644 index 00000000000..b2b9004457b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9bf235679af1ca03a6e601b4cf6cd0416d1c9091.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9c4fc7cda4b560040cec93f63021b529aa1ee3fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9c4fc7cda4b560040cec93f63021b529aa1ee3fd.hip new file mode 100644 index 00000000000..401c9980bc0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9c4fc7cda4b560040cec93f63021b529aa1ee3fd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ca3b1d36d777213eb381b47871bf15dd163c994.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ca3b1d36d777213eb381b47871bf15dd163c994.hip new file mode 100644 index 00000000000..288085b7f03 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ca3b1d36d777213eb381b47871bf15dd163c994.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9cc3ef3d3b36f52089548e9dce522b0448e2c26a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9cc3ef3d3b36f52089548e9dce522b0448e2c26a.hip new file mode 100644 index 00000000000..d93a05b6a55 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9cc3ef3d3b36f52089548e9dce522b0448e2c26a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d3d274058bc0a3d4d35d90669587761fdfbdba1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d3d274058bc0a3d4d35d90669587761fdfbdba1.hip new file mode 100644 index 00000000000..b154cf2e3a2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d3d274058bc0a3d4d35d90669587761fdfbdba1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else 
+ return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d6759d8855c4c6289f1f241a1628cf0406c1b64.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d6759d8855c4c6289f1f241a1628cf0406c1b64.hip new file mode 100644 index 00000000000..5eb4f26ae7b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d6759d8855c4c6289f1f241a1628cf0406c1b64.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d69d441f48f9ea346dd8e00376a9a708da3ad87.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d69d441f48f9ea346dd8e00376a9a708da3ad87.hip new file mode 100644 index 00000000000..13506066037 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d69d441f48f9ea346dd8e00376a9a708da3ad87.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9dc424f0e192155e3c4e786e5b87d5a1a3e6c4ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9dc424f0e192155e3c4e786e5b87d5a1a3e6c4ad.hip new file mode 100644 index 00000000000..da1632ff5de --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9dc424f0e192155e3c4e786e5b87d5a1a3e6c4ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9e51083e13aa4dfa8c969f8f916835a8e5e9ca39.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9e51083e13aa4dfa8c969f8f916835a8e5e9ca39.hip new file mode 100644 index 00000000000..5a974dfe44d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9e51083e13aa4dfa8c969f8f916835a8e5e9ca39.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9eef1b54d5d3841f3fa6b84cca6c7ad33efa2d9f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9eef1b54d5d3841f3fa6b84cca6c7ad33efa2d9f.hip new file mode 100644 index 00000000000..7217aff87c7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9eef1b54d5d3841f3fa6b84cca6c7ad33efa2d9f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9f0517550c7a23882b95de451e8099ea2186b4ce.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9f0517550c7a23882b95de451e8099ea2186b4ce.hip new file mode 100644 index 00000000000..49834b64050 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9f0517550c7a23882b95de451e8099ea2186b4ce.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9fb389d4b5ba590baa951f17da06f0e53d2bfa55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9fb389d4b5ba590baa951f17da06f0e53d2bfa55.hip new file mode 100644 index 00000000000..1106b708d9f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9fb389d4b5ba590baa951f17da06f0e53d2bfa55.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a017be7b8bcf303b30a147f41346898acc5fab7d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a017be7b8bcf303b30a147f41346898acc5fab7d.hip new file mode 100644 index 00000000000..29a53b80166 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a017be7b8bcf303b30a147f41346898acc5fab7d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a02a71fdd587e47ee68e0cc76c3c4494ce06c359.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a02a71fdd587e47ee68e0cc76c3c4494ce06c359.hip new file mode 100644 index 00000000000..5c76e72928b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a02a71fdd587e47ee68e0cc76c3c4494ce06c359.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a02f152e9184af0b3d77082d8bdf519dbbfceb2d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a02f152e9184af0b3d77082d8bdf519dbbfceb2d.hip new file mode 100644 index 00000000000..9f041ce8f56 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a02f152e9184af0b3d77082d8bdf519dbbfceb2d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a046e888e3836b0bd3c49fec8e1872e880798f0c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a046e888e3836b0bd3c49fec8e1872e880798f0c.hip new file mode 100644 index 00000000000..de2940480d7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a046e888e3836b0bd3c49fec8e1872e880798f0c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a0874fc5ac87a1ec487c7722bf3b1bdaa924ee09.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a0874fc5ac87a1ec487c7722bf3b1bdaa924ee09.hip new file mode 100644 index 00000000000..53aaa3daba9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a0874fc5ac87a1ec487c7722bf3b1bdaa924ee09.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a094599fb5caf5e7aba728cd4713a8d0c6368a46.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a094599fb5caf5e7aba728cd4713a8d0c6368a46.hip new file mode 100644 index 00000000000..553876b0ac3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a094599fb5caf5e7aba728cd4713a8d0c6368a46.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a0a556c9358ddd6db719458c81d2d6d822a895da.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a0a556c9358ddd6db719458c81d2d6d822a895da.hip new file mode 100644 index 00000000000..fca384536e5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a0a556c9358ddd6db719458c81d2d6d822a895da.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a103cd47156a98ad2cf2c325ea00df3f1d67fb72.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a103cd47156a98ad2cf2c325ea00df3f1d67fb72.hip new file mode 100644 index 00000000000..83f54dcbfa2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a103cd47156a98ad2cf2c325ea00df3f1d67fb72.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a189292c81a18d21a2921ce6740f81ebf4c046ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a189292c81a18d21a2921ce6740f81ebf4c046ad.hip new file mode 100644 index 00000000000..e2ead7d8b1e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a189292c81a18d21a2921ce6740f81ebf4c046ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1c71e7d33f0597fe090a3524e33e18b2e562680.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1c71e7d33f0597fe090a3524e33e18b2e562680.hip new file mode 100644 index 00000000000..1d2ace1439b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1c71e7d33f0597fe090a3524e33e18b2e562680.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1cba1509c413c870c5d784410855ee1bd737da2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1cba1509c413c870c5d784410855ee1bd737da2.hip new file mode 100644 index 00000000000..d40fb018bb2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1cba1509c413c870c5d784410855ee1bd737da2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1d6ad9de7ac7993ae1923a2ef070b7dacb8c563.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1d6ad9de7ac7993ae1923a2ef070b7dacb8c563.hip new file mode 100644 index 00000000000..19561dc031b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1d6ad9de7ac7993ae1923a2ef070b7dacb8c563.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a20c91b2f11bb7e5058ca7935b0bda4f5558a9dc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a20c91b2f11bb7e5058ca7935b0bda4f5558a9dc.hip new file mode 100644 index 00000000000..6f2730e273f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a20c91b2f11bb7e5058ca7935b0bda4f5558a9dc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a21f3637624762547af1292e1b85e640b1d329dc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a21f3637624762547af1292e1b85e640b1d329dc.hip new file mode 100644 index 00000000000..2153f8dcc77 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a21f3637624762547af1292e1b85e640b1d329dc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a225c4f1f3c7b271957768bb9235131c67afb48a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a225c4f1f3c7b271957768bb9235131c67afb48a.hip new file mode 100644 index 00000000000..ade8b04d9ee --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a225c4f1f3c7b271957768bb9235131c67afb48a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2482a64659c838f3da55f56e3cbbee1dbfe6722.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2482a64659c838f3da55f56e3cbbee1dbfe6722.hip new file mode 100644 index 00000000000..2114626fba4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2482a64659c838f3da55f56e3cbbee1dbfe6722.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a25e2aed617e1ff31f93ae7e054313ee0dceee97.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a25e2aed617e1ff31f93ae7e054313ee0dceee97.hip new file mode 100644 index 00000000000..feda9895fa6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a25e2aed617e1ff31f93ae7e054313ee0dceee97.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2a715b7e9c1a576f011dfe5769c5b392e984f82.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2a715b7e9c1a576f011dfe5769c5b392e984f82.hip new file mode 100644 index 00000000000..f940a11d4eb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2a715b7e9c1a576f011dfe5769c5b392e984f82.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2ef5d30a2318ae06430d17f84878800c4ca7364.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2ef5d30a2318ae06430d17f84878800c4ca7364.hip new file mode 100644 index 00000000000..ab1e7651e61 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2ef5d30a2318ae06430d17f84878800c4ca7364.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3339150d8bf9d073827738527f6cbe15b854607.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3339150d8bf9d073827738527f6cbe15b854607.hip new file mode 100644 index 00000000000..866d9f28315 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3339150d8bf9d073827738527f6cbe15b854607.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3709e4fc53d2254a03ea7660b8c72d2f47cf1ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3709e4fc53d2254a03ea7660b8c72d2f47cf1ad.hip new file mode 100644 index 00000000000..cd987b34952 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3709e4fc53d2254a03ea7660b8c72d2f47cf1ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a388a284f45f711d82a6ed87036d87cef1872eb1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a388a284f45f711d82a6ed87036d87cef1872eb1.hip new file mode 100644 index 00000000000..74201ad912e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a388a284f45f711d82a6ed87036d87cef1872eb1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3ac4f93722dc314086f1b7d7b8adc687cd75f82.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3ac4f93722dc314086f1b7d7b8adc687cd75f82.hip new file mode 100644 index 00000000000..2641e736fc7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3ac4f93722dc314086f1b7d7b8adc687cd75f82.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
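// The unit that follows is one of the standalone dQ-conversion kernels the backward
// pass also emits: a separate FmhaBwdConvertQGradKernel casts the dQ accumulator
// (AccDataType, typically fp32) down to the output QGradDataType (fp16 here). A rough
// sketch of what that conversion amounts to, with illustrative names only:
//
//   for (std::size_t i = 0; i < n; ++i)                    // acc: wide dQ accumulator
//       dq_out[i] = static_cast<ck_tile::fp16_t>(acc[i]);  // dq_out: final narrow dQ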
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3d7aa46528ee74e2bef1e87c1feceacfa55e173.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3d7aa46528ee74e2bef1e87c1feceacfa55e173.hip new file mode 100644 index 00000000000..b78f6fbd44b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3d7aa46528ee74e2bef1e87c1feceacfa55e173.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3dc780b17152f696f9b957432c2eae8fb16e85e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3dc780b17152f696f9b957432c2eae8fb16e85e.hip new file mode 100644 index 00000000000..7974289e9b6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3dc780b17152f696f9b957432c2eae8fb16e85e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3f9c236d24b30bc9c3fad90cfd6eb00da835de2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3f9c236d24b30bc9c3fad90cfd6eb00da835de2.hip new file mode 100644 index 00000000000..3eebf99d1a5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3f9c236d24b30bc9c3fad90cfd6eb00da835de2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
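// All of the generated translation units in this directory follow the same pattern:
// a shared header declares a generic entry point templated on a trait type, and each
// file pins one tile/trait configuration and supplies the matching explicit
// specializations (the launch, oneshot and get_name functions further down). A rough,
// simplified sketch of that pattern (the names below are illustrative and do not
// appear in this file):
//
//   // common header, declared once
//   template <typename Trait>
//   float run_fmha_bwd(const ck_tile::stream_config&, fmha_bwd_args);
//
//   // one generated .hip per Trait
//   template <>
//   float run_fmha_bwd<some_trait>(const ck_tile::stream_config& s, fmha_bwd_args a)
//   {
//       /* build kargs/grids for the pinned kernel and launch it */
//       return 0.0f;
//   }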
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3ff8445ba691807caadd9f26e7eb90851875280.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3ff8445ba691807caadd9f26e7eb90851875280.hip new file mode 100644 index 00000000000..cb73d31ebaf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3ff8445ba691807caadd9f26e7eb90851875280.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a421c2ed6b295c458071f1988b9d6f7b46e8992c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a421c2ed6b295c458071f1988b9d6f7b46e8992c.hip new file mode 100644 index 00000000000..f44a100c7dc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a421c2ed6b295c458071f1988b9d6f7b46e8992c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4700d87a19a173e84d64e43cffabbed52366e35.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4700d87a19a173e84d64e43cffabbed52366e35.hip new file mode 100644 index 00000000000..6bd087d8c11 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4700d87a19a173e84d64e43cffabbed52366e35.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
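// Note on the launch bodies further down in this unit: the actual kernel launch is
// wrapped in #if (defined(__gfx90a__) || defined(__gfx942__)), so device code is only
// emitted when compiling for MI200 (gfx90a) or MI300 (gfx942); for any other target the
// specialization degrades to a stub that skips the launch and returns 0.0, presumably
// so the file still compiles cleanly for architectures the CK kernels do not cover.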
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a487f617c4b84c6a0328fedac750d41dc3dafe27.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a487f617c4b84c6a0328fedac750d41dc3dafe27.hip new file mode 100644 index 00000000000..601f2e7e5c6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a487f617c4b84c6a0328fedac750d41dc3dafe27.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a48843d844f78690c7a45b730652f0f763c595c7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a48843d844f78690c7a45b730652f0f763c595c7.hip new file mode 100644 index 00000000000..77c01f585a0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a48843d844f78690c7a45b730652f0f763c595c7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4980becb0d3149fee575bad1fc3b463d08aabf5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4980becb0d3149fee575bad1fc3b463d08aabf5.hip new file mode 100644 index 00000000000..f0add1a5b8b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4980becb0d3149fee575bad1fc3b463d08aabf5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4b7f10440331a8a88ff93ba253217c2832bcf9e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4b7f10440331a8a88ff93ba253217c2832bcf9e.hip new file mode 100644 index 00000000000..e45d80c6a56 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4b7f10440331a8a88ff93ba253217c2832bcf9e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
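// In the body below (and in every sibling file), the ck_tile::sequence<> aliases pin
// the block-tile extents and the per-GEMM warp partitioning that the TileFmhaBwdShape
// is assembled from; the dtype, these tile numbers, and the trait/pipeline/bias flags
// are essentially the only things generate.py varies between the otherwise identical
// files in this directory.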
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a55b47aafc4340e69e300ac61a7601a5c14513b7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a55b47aafc4340e69e300ac61a7601a5c14513b7.hip new file mode 100644 index 00000000000..e062a838409 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a55b47aafc4340e69e300ac61a7601a5c14513b7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a55c7dd576e5b1061c059e5e99aeedf4389e2d25.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a55c7dd576e5b1061c059e5e99aeedf4389e2d25.hip new file mode 100644 index 00000000000..98a8d128207 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a55c7dd576e5b1061c059e5e99aeedf4389e2d25.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a59423c095db052603d77073d409534bceef425f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a59423c095db052603d77073d409534bceef425f.hip new file mode 100644 index 00000000000..78329df0c5c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a59423c095db052603d77073d409534bceef425f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5a7833f4597bb03a3e845d5580d677e97421040.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5a7833f4597bb03a3e845d5580d677e97421040.hip new file mode 100644 index 00000000000..bba679dadf6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5a7833f4597bb03a3e845d5580d677e97421040.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5bdc110955c05c6c6ea236a6f60266a4a6dce5e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5bdc110955c05c6c6ea236a6f60266a4a6dce5e.hip new file mode 100644 index 00000000000..efde34b2f2c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5bdc110955c05c6c6ea236a6f60266a4a6dce5e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5c0109313de1f6245d2a80f8539485b849e9d55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5c0109313de1f6245d2a80f8539485b849e9d55.hip new file mode 100644 index 00000000000..31f86f982f0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5c0109313de1f6245d2a80f8539485b849e9d55.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5c4dc0d70c547dbbfb661e879ba7f9adfafc2ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5c4dc0d70c547dbbfb661e879ba7f9adfafc2ea.hip new file mode 100644 index 00000000000..71649189b5d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5c4dc0d70c547dbbfb661e879ba7f9adfafc2ea.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5d4eb673bafd81e3a0ee213da4603d88b8460ec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5d4eb673bafd81e3a0ee213da4603d88b8460ec.hip new file mode 100644 index 00000000000..dd74f1bd636 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5d4eb673bafd81e3a0ee213da4603d88b8460ec.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5e5cae764142683b70d3344cf07dd1edb7d69e2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5e5cae764142683b70d3344cf07dd1edb7d69e2.hip new file mode 100644 index 00000000000..73465e5777d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5e5cae764142683b70d3344cf07dd1edb7d69e2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5f2f0cef657ae5e333d65ae4ab20529a43cd7de.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5f2f0cef657ae5e333d65ae4ab20529a43cd7de.hip new file mode 100644 index 00000000000..7b05991667c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5f2f0cef657ae5e333d65ae4ab20529a43cd7de.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5f8b7b2a891aa9f2ab49762eb31d835efdf18b6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5f8b7b2a891aa9f2ab49762eb31d835efdf18b6.hip new file mode 100644 index 00000000000..e1e538835e1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5f8b7b2a891aa9f2ab49762eb31d835efdf18b6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5fa94bb32a80e81886b711ebfcf2df5f5405866.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5fa94bb32a80e81886b711ebfcf2df5f5405866.hip new file mode 100644 index 00000000000..2ab79463c5b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5fa94bb32a80e81886b711ebfcf2df5f5405866.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a622fa57764ec746e02f6d4bd4846b48c722b807.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a622fa57764ec746e02f6d4bd4846b48c722b807.hip new file mode 100644 index 00000000000..46e0c94e86b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a622fa57764ec746e02f6d4bd4846b48c722b807.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + 
constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a62a2ab489839ea1a1bfd1b24e54a3c232ed934f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a62a2ab489839ea1a1bfd1b24e54a3c232ed934f.hip new file mode 100644 index 00000000000..17d1bb9530e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a62a2ab489839ea1a1bfd1b24e54a3c232ed934f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + 
ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a6461d72fb6ba50e81de3f661528c96dcfdc3f3c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a6461d72fb6ba50e81de3f661528c96dcfdc3f3c.hip new file mode 100644 index 00000000000..44eb1986a66 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a6461d72fb6ba50e81de3f661528c96dcfdc3f3c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a64b4cf3f6706e4b4e0af4402e2263b9a1585f9b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a64b4cf3f6706e4b4e0af4402e2263b9a1585f9b.hip new file mode 100644 index 00000000000..e793e94adb6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a64b4cf3f6706e4b4e0af4402e2263b9a1585f9b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a65c43b870705c780d734f9ef063f55cf8b3b52d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a65c43b870705c780d734f9ef063f55cf8b3b52d.hip new file mode 100644 index 00000000000..3a3b4e69a88 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a65c43b870705c780d734f9ef063f55cf8b3b52d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a673f35edd69241c6b921d6712dfd064d78ecbad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a673f35edd69241c6b921d6712dfd064d78ecbad.hip new file mode 100644 index 00000000000..bf2296db500 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a673f35edd69241c6b921d6712dfd064d78ecbad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a71305f191f06cd53b7563971c706e8b71b19e2f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a71305f191f06cd53b7563971c706e8b71b19e2f.hip new file mode 100644 index 00000000000..0642dace348 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a71305f191f06cd53b7563971c706e8b71b19e2f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a74b0e7dd816ad08eec5a1bba6e227afee9813ec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a74b0e7dd816ad08eec5a1bba6e227afee9813ec.hip new file mode 100644 index 00000000000..4570446797c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a74b0e7dd816ad08eec5a1bba6e227afee9813ec.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a7784b03ad757d51c234fa86ea9891f055ecd5c1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a7784b03ad757d51c234fa86ea9891f055ecd5c1.hip new file mode 100644 index 00000000000..7ecbad53f90 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a7784b03ad757d51c234fa86ea9891f055ecd5c1.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + true, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a78fecb9725ceb4bcf2aa037d43bc43efeb1c3fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a78fecb9725ceb4bcf2aa037d43bc43efeb1c3fd.hip new file mode 100644 index 00000000000..09b7dce1884 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a78fecb9725ceb4bcf2aa037d43bc43efeb1c3fd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a7f7553a7d2f6d42fe695cdc64423c85223af440.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a7f7553a7d2f6d42fe695cdc64423c85223af440.hip new file mode 100644 index 00000000000..e98471e8dc2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a7f7553a7d2f6d42fe695cdc64423c85223af440.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a821661d8280c6e9d27f2c9ce1b3c855387b5a76.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a821661d8280c6e9d27f2c9ce1b3c855387b5a76.hip new file mode 100644 index 00000000000..5ff5eb5428e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a821661d8280c6e9d27f2c9ce1b3c855387b5a76.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a85d35b2fd98742427930eb536e346ffb005edd8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a85d35b2fd98742427930eb536e346ffb005edd8.hip new file mode 100644 index 00000000000..2df4bb29523 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a85d35b2fd98742427930eb536e346ffb005edd8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a8a4af070ee46d802cb11086b93daf91538f8a04.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a8a4af070ee46d802cb11086b93daf91538f8a04.hip new file mode 100644 index 00000000000..30c74630350 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a8a4af070ee46d802cb11086b93daf91538f8a04.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a8a744edfa3a19d1493611df5bd0d4d59b707d43.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a8a744edfa3a19d1493611df5bd0d4d59b707d43.hip new file mode 100644 index 00000000000..4d052ab3db8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a8a744edfa3a19d1493611df5bd0d4d59b707d43.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a92b43d374642df991edef1f6036dc898bf77cf8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a92b43d374642df991edef1f6036dc898bf77cf8.hip new file mode 100644 index 00000000000..10985b37d79 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a92b43d374642df991edef1f6036dc898bf77cf8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a93324ccf11b273ed20fd960c61df897c8890b1d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a93324ccf11b273ed20fd960c61df897c8890b1d.hip new file mode 100644 index 00000000000..489e0d778e2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a93324ccf11b273ed20fd960c61df897c8890b1d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a93a03b33305b33055273711ab31a5b8d8298d5d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a93a03b33305b33055273711ab31a5b8d8298d5d.hip new file mode 100644 index 00000000000..3e858ad2835 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a93a03b33305b33055273711ab31a5b8d8298d5d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a968df29f5ae1463706b7981b3bde55918e1aa65.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a968df29f5ae1463706b7981b3bde55918e1aa65.hip new file mode 100644 index 00000000000..c9c1b381324 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a968df29f5ae1463706b7981b3bde55918e1aa65.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a98925d99dc484da41dd55700e151cf545cf821d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a98925d99dc484da41dd55700e151cf545cf821d.hip new file mode 100644 index 00000000000..82d7bcf2cc1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a98925d99dc484da41dd55700e151cf545cf821d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9b50c6ebb27986ce5b378d8c39315eb9cb91dea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9b50c6ebb27986ce5b378d8c39315eb9cb91dea.hip new file mode 100644 index 00000000000..0f7bab218f2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9b50c6ebb27986ce5b378d8c39315eb9cb91dea.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9d2be18e2d53a5144f97dfdebb225fcb6d611d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9d2be18e2d53a5144f97dfdebb225fcb6d611d3.hip new file mode 100644 index 00000000000..707470e0d31 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9d2be18e2d53a5144f97dfdebb225fcb6d611d3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9df9ac4ee78e5f4d5bd0567e58a7090907c61e1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9df9ac4ee78e5f4d5bd0567e58a7090907c61e1.hip new file mode 100644 index 00000000000..99dbd992dfb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9df9ac4ee78e5f4d5bd0567e58a7090907c61e1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9f00f270680de81df7737e848e0408cb070e68b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9f00f270680de81df7737e848e0408cb070e68b.hip new file mode 100644 index 00000000000..412714e9cc9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9f00f270680de81df7737e848e0408cb070e68b.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa1041530f794c7b8dc4a8321ea0fcdd338fff35.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa1041530f794c7b8dc4a8321ea0fcdd338fff35.hip new file mode 100644 index 00000000000..7297f13a93d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa1041530f794c7b8dc4a8321ea0fcdd338fff35.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa522b43c5e5ea69bcabb4c0fe28def2bd081a12.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa522b43c5e5ea69bcabb4c0fe28def2bd081a12.hip new file mode 100644 index 00000000000..6a4b0fd85e5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa522b43c5e5ea69bcabb4c0fe28def2bd081a12.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa6d13b09f85ee62bb5018608812181fb43afc86.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa6d13b09f85ee62bb5018608812181fb43afc86.hip new file mode 100644 index 00000000000..0c54a98de49 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa6d13b09f85ee62bb5018608812181fb43afc86.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa82d20635e592edbf00439294835f6f39ad54a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa82d20635e592edbf00439294835f6f39ad54a3.hip new file mode 100644 index 00000000000..d3e40eae7c1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa82d20635e592edbf00439294835f6f39ad54a3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa996b9c843200a2ec33ed4319b48106cd7c6384.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa996b9c843200a2ec33ed4319b48106cd7c6384.hip new file mode 100644 index 00000000000..046c905dcc3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa996b9c843200a2ec33ed4319b48106cd7c6384.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aafe891dad43815e635f81225705ff944f990d75.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aafe891dad43815e635f81225705ff944f990d75.hip new file mode 100644 index 00000000000..a3a3500bf70 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aafe891dad43815e635f81225705ff944f990d75.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab09941bddfa9d61985b55f9b6bf0edec9bb89f6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab09941bddfa9d61985b55f9b6bf0edec9bb89f6.hip new file mode 100644 index 00000000000..e840e2f86c4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab09941bddfa9d61985b55f9b6bf0edec9bb89f6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab0be5a2072b5e87f5ee58149688796b6513219f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab0be5a2072b5e87f5ee58149688796b6513219f.hip new file mode 100644 index 00000000000..06df0b0f647 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab0be5a2072b5e87f5ee58149688796b6513219f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab0c3fe9529e24327686070731d0ac3ada76245e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab0c3fe9529e24327686070731d0ac3ada76245e.hip new file mode 100644 index 00000000000..efefcc88dbf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab0c3fe9529e24327686070731d0ac3ada76245e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab1ca4ce061f7f69a250356f613cab00d1e2ac71.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab1ca4ce061f7f69a250356f613cab00d1e2ac71.hip new file mode 100644 index 00000000000..c36699649bc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab1ca4ce061f7f69a250356f613cab00d1e2ac71.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab1d7f93427095e39bfc1d986b3d7fe54073ec75.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab1d7f93427095e39bfc1d986b3d7fe54073ec75.hip new file mode 100644 index 00000000000..908b3eebd84 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab1d7f93427095e39bfc1d986b3d7fe54073ec75.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab43f4a56c166dad0113f51b337a083f4df7cdb6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab43f4a56c166dad0113f51b337a083f4df7cdb6.hip new file mode 100644 index 00000000000..3f5d67fb6bc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab43f4a56c166dad0113f51b337a083f4df7cdb6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab56e886d53a1d88fada0f10f00b9f398dc54568.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab56e886d53a1d88fada0f10f00b9f398dc54568.hip new file mode 100644 index 00000000000..10aef5450c1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab56e886d53a1d88fada0f10f00b9f398dc54568.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab6cd5c9242f8278c8f3d9ce57b97d605c7e5a3e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab6cd5c9242f8278c8f3d9ce57b97d605c7e5a3e.hip new file mode 100644 index 00000000000..b7ff8d40e9a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab6cd5c9242f8278c8f3d9ce57b97d605c7e5a3e.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab877ae2a1aab04498bf2b26b3fe99d6488ef151.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab877ae2a1aab04498bf2b26b3fe99d6488ef151.hip new file mode 100644 index 00000000000..48a337f18f7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab877ae2a1aab04498bf2b26b3fe99d6488ef151.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
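// Note on the dispatch pattern used by every generated unit in this directory:
// each .hip file instantiates exactly one kernel variant and provides the
// matching explicit specialization of a dispatch template (fmha_fwd_<...>,
// fmha_bwd_dq_dk_dv_<...>, fmha_bwd_convert_dq_<...>), keyed by a *_traits_
// alias that encodes head dimension, data type, pipeline, mask/bias/dropout
// and padding options. A simplified, hypothetical sketch of the idiom (the
// names below are made up; the real traits carry many more parameters):
//
//   template <typename Trait> float run_variant(const Args&);          // declared once in a shared header
//   template <> float run_variant<TraitA>(const Args& a) { /* A */ }   // one generated file
//   template <> float run_variant<TraitB>(const Args& a) { /* B */ }   // another generated file
//
// The host-side dispatcher resolves the trait that matches the runtime
// problem and calls the corresponding specialization, so only the selected
// variant is ever launched.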
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_abf6c6412f9853855b74a96e862935ddef66f763.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_abf6c6412f9853855b74a96e862935ddef66f763.hip new file mode 100644 index 00000000000..67d289a046a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_abf6c6412f9853855b74a96e862935ddef66f763.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_abf92a5314fd33491b5eb6ebd2418b7e0d5db774.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_abf92a5314fd33491b5eb6ebd2418b7e0d5db774.hip new file mode 100644 index 00000000000..4981b113d3f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_abf92a5314fd33491b5eb6ebd2418b7e0d5db774.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac1ccde31b47e0e56ee0daab6403fed7895208c7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac1ccde31b47e0e56ee0daab6403fed7895208c7.hip new file mode 100644 index 00000000000..e4cb72b31f6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac1ccde31b47e0e56ee0daab6403fed7895208c7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
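// The forward units such as the one below assemble a kernel from three
// pieces: a tile/warp shape (fmha_shape_0), a compute pipeline
// (ck_tile::BlockFmhaPipelineQRKSVS over the pipeline problem), and a
// Default2DEpilogue that writes the O output, all wrapped into a
// ck_tile::FmhaFwdKernel. The fmha_fwd_traits_<...> alias near the bottom is
// the dispatch key; the fmha_fwd_<trait_0> specialization builds kernel
// arguments and grid dimensions from fmha_fwd_args, launches the kernel, and
// returns the value produced by ck_tile::launch_kernel.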
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac5e9aee85cd16903bf7b82a4ac10402b0b26e22.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac5e9aee85cd16903bf7b82a4ac10402b0b26e22.hip new file mode 100644 index 00000000000..d93079f7d8d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac5e9aee85cd16903bf7b82a4ac10402b0b26e22.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
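// The backward pass is split across two kinds of generated units: the
// fmha_bwd_dq_dk_dv_<...> kernels (like the one below) compute dQ/dK/dV with
// the KRKTRVR or KRKTRVR_IGLP pipeline, and the smaller
// fmha_bwd_convert_dq_<...> kernels seen earlier in this directory cast the
// fp32 dQ accumulator (AccDataType) back to the requested output type
// (QGradDataType). The two pipeline enums appear to select different
// instruction-scheduling strategies for the same dQ/dK/dV algorithm; both are
// emitted so the dispatcher can choose per configuration.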
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac9382cf8bb56ffd962c99329bf67da992f8810d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac9382cf8bb56ffd962c99329bf67da992f8810d.hip new file mode 100644 index 00000000000..5f0d722e5f0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac9382cf8bb56ffd962c99329bf67da992f8810d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aceb0641213e9a45ba48bcf72bb23845720d8b79.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aceb0641213e9a45ba48bcf72bb23845720d8b79.hip new file mode 100644 index 00000000000..5792beb0f34 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aceb0641213e9a45ba48bcf72bb23845720d8b79.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad091c69d19b27f7ad50ef6311532ad8b642a9c6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad091c69d19b27f7ad50ef6311532ad8b642a9c6.hip new file mode 100644 index 00000000000..25ecd7ae6e3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad091c69d19b27f7ad50ef6311532ad8b642a9c6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
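// Every launcher in these files is wrapped in the same architecture guard:
// device code is compiled only for gfx90a and gfx942 (MI200- and MI300-class
// GPUs). On any other target the specialization degenerates to a stub that
// returns 0.0 and the *_oneshot_ variant becomes a no-op, so the files link
// everywhere but the kernels are only usable on the listed architectures.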
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad82071cc074fd30437f6158b5eb2c6df1f8c587.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad82071cc074fd30437f6158b5eb2c6df1f8c587.hip new file mode 100644 index 00000000000..df6ef5ff27d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad82071cc074fd30437f6158b5eb2c6df1f8c587.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad989d2ce769f20e175fa88f4082c1c25fe03062.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad989d2ce769f20e175fa88f4082c1c25fe03062.hip new file mode 100644 index 00000000000..b02ddff6364 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad989d2ce769f20e175fa88f4082c1c25fe03062.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad9b99a194b59d3149842c15733394da275b12c0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad9b99a194b59d3149842c15733394da275b12c0.hip new file mode 100644 index 00000000000..ae5eb09a866 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad9b99a194b59d3149842c15733394da275b12c0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
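// The trailing booleans in each *_traits_ alias (and the matching flags in
// TileFmhaTraits) toggle per-variant options; judging by the generator they
// cover things such as sequence-length and head-dimension padding, bias and
// bias-gradient support, LSE/dropout storage, and deterministic accumulation.
// The exact parameter order is defined by the CK headers these files include,
// so the values here should be read against those definitions rather than
// edited by hand.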
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ada016be2bd0e377fbe01fa7adb9bbb8febce100.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ada016be2bd0e377fbe01fa7adb9bbb8febce100.hip new file mode 100644 index 00000000000..5e4c5035b79 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ada016be2bd0e377fbe01fa7adb9bbb8febce100.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adae2d4f8b2dac799e03ea6f279e6ecdf66f5381.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adae2d4f8b2dac799e03ea6f279e6ecdf66f5381.hip new file mode 100644 index 00000000000..0e59e52fbe8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adae2d4f8b2dac799e03ea6f279e6ecdf66f5381.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adaef10ff2c5d89530310bdf1d53a194f06a94ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adaef10ff2c5d89530310bdf1d53a194f06a94ef.hip new file mode 100644 index 00000000000..4ecc422eb58 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adaef10ff2c5d89530310bdf1d53a194f06a94ef.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_add29e3e9828911a117dccaa5650e77805730d14.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_add29e3e9828911a117dccaa5650e77805730d14.hip new file mode 100644 index 00000000000..daee5524a5a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_add29e3e9828911a117dccaa5650e77805730d14.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adda7ad787524e3e47dcc1b65c41b2faea38f55f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adda7ad787524e3e47dcc1b65c41b2faea38f55f.hip new file mode 100644 index 00000000000..14419e0463f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adda7ad787524e3e47dcc1b65c41b2faea38f55f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_addb6a14043c5a4df0f5042b3770b40c4e90795c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_addb6a14043c5a4df0f5042b3770b40c4e90795c.hip new file mode 100644 index 00000000000..29702caa878 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_addb6a14043c5a4df0f5042b3770b40c4e90795c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adf160741a4f751d2f15d6eb23d4121cdca62b55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adf160741a4f751d2f15d6eb23d4121cdca62b55.hip new file mode 100644 index 00000000000..cb0d58a3c82 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adf160741a4f751d2f15d6eb23d4121cdca62b55.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae1ab1f4bbe86bb9bbc22e4774648076c321136f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae1ab1f4bbe86bb9bbc22e4774648076c321136f.hip new file mode 100644 index 00000000000..4689a93c01b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae1ab1f4bbe86bb9bbc22e4774648076c321136f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae1afeb6cfdf860ff08e4c2f11c922fd5bfa621a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae1afeb6cfdf860ff08e4c2f11c922fd5bfa621a.hip new file mode 100644 index 00000000000..97ed99b87ae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae1afeb6cfdf860ff08e4c2f11c922fd5bfa621a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae239476d61f48379754b97f29d7a285cc3192de.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae239476d61f48379754b97f29d7a285cc3192de.hip new file mode 100644 index 00000000000..b8ec64735f7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae239476d61f48379754b97f29d7a285cc3192de.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae4e7253ad4873576052ec0a9400597bb7975753.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae4e7253ad4873576052ec0a9400597bb7975753.hip new file mode 100644 index 00000000000..7bf4ca0ae48 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae4e7253ad4873576052ec0a9400597bb7975753.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae4e80cb185759dd9b3eb3c67c239964b3694caa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae4e80cb185759dd9b3eb3c67c239964b3694caa.hip new file mode 100644 index 00000000000..8354935ff71 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae4e80cb185759dd9b3eb3c67c239964b3694caa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae51b30c7e1cd30e550187458350c8db7c59a9ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae51b30c7e1cd30e550187458350c8db7c59a9ef.hip new file mode 100644 index 00000000000..b1c013312fd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae51b30c7e1cd30e550187458350c8db7c59a9ef.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae7899b1ef159ecbf01f27014601eb79b31b49b3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae7899b1ef159ecbf01f27014601eb79b31b49b3.hip new file mode 100644 index 00000000000..60b4fff0c63 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae7899b1ef159ecbf01f27014601eb79b31b49b3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae87b1d5c50606430b544ed650d87df24366e7d5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae87b1d5c50606430b544ed650d87df24366e7d5.hip new file mode 100644 index 00000000000..94e2e5f888a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae87b1d5c50606430b544ed650d87df24366e7d5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae8d0bdde763e617beafc0365ec4a3cd11df6c55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae8d0bdde763e617beafc0365ec4a3cd11df6c55.hip new file mode 100644 index 00000000000..6fafb6b6150 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae8d0bdde763e617beafc0365ec4a3cd11df6c55.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebb2441e6cc1ccba4a391566e547402bcf7ced2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebb2441e6cc1ccba4a391566e547402bcf7ced2.hip new file mode 100644 index 00000000000..df12fb8188b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebb2441e6cc1ccba4a391566e547402bcf7ced2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebd5fed34ebceb879ae3dffaf58c7c04ab5fe80.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebd5fed34ebceb879ae3dffaf58c7c04ab5fe80.hip new file mode 100644 index 00000000000..c457e60f89b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebd5fed34ebceb879ae3dffaf58c7c04ab5fe80.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebff7e6605b273bad844b8f70ef031625bff48e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebff7e6605b273bad844b8f70ef031625bff48e.hip new file mode 100644 index 00000000000..7856926c691 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebff7e6605b273bad844b8f70ef031625bff48e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aec87e65afa93e84d7a947c52f291c1c7360033c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aec87e65afa93e84d7a947c52f291c1c7360033c.hip new file mode 100644 index 00000000000..dc055325486 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aec87e65afa93e84d7a947c52f291c1c7360033c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aece14f7a220222eb4ce6783ec2b9fce6fde94b8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aece14f7a220222eb4ce6783ec2b9fce6fde94b8.hip new file mode 100644 index 00000000000..e62cc9116e2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aece14f7a220222eb4ce6783ec2b9fce6fde94b8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else 
+ return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_af06c0dae15684f83e15722a4c07342af9ea011c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_af06c0dae15684f83e15722a4c07342af9ea011c.hip new file mode 100644 index 00000000000..70a2897b4a0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_af06c0dae15684f83e15722a4c07342af9ea011c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_af6ccfa11add1ae49888337e84d9c446d2f67da4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_af6ccfa11add1ae49888337e84d9c446d2f67da4.hip new file mode 100644 index 00000000000..f83afc6ed97 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_af6ccfa11add1ae49888337e84d9c446d2f67da4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afadc4f76e237514db0bc0203102297b79730bd0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afadc4f76e237514db0bc0203102297b79730bd0.hip new file mode 100644 index 00000000000..1f003f7d4fe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afadc4f76e237514db0bc0203102297b79730bd0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afc4b47a6fa62a4ca5cff6a7e01c9f6b371d2215.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afc4b47a6fa62a4ca5cff6a7e01c9f6b371d2215.hip new file mode 100644 index 00000000000..b00cacf498f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afc4b47a6fa62a4ca5cff6a7e01c9f6b371d2215.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afcafd07c1f56e74373ccf37db35976023456d50.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afcafd07c1f56e74373ccf37db35976023456d50.hip new file mode 100644 index 00000000000..939fe306551 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afcafd07c1f56e74373ccf37db35976023456d50.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afccf699f593c828e11efc053b144044e45b32d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afccf699f593c828e11efc053b144044e45b32d6.hip new file mode 100644 index 00000000000..ac8ba2d9319 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afccf699f593c828e11efc053b144044e45b32d6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afda8f46b5ded4c2aa9d722fec17b75004b59f7d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afda8f46b5ded4c2aa9d722fec17b75004b59f7d.hip new file mode 100644 index 00000000000..ec55820cb67 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afda8f46b5ded4c2aa9d722fec17b75004b59f7d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afdab954fd111ec48721f25710d61c0c8affd8db.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afdab954fd111ec48721f25710d61c0c8affd8db.hip new file mode 100644 index 00000000000..622ee71147a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afdab954fd111ec48721f25710d61c0c8affd8db.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b00e062055933388e37525df5766f3c14cd3538a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b00e062055933388e37525df5766f3c14cd3538a.hip new file mode 100644 index 00000000000..c722c77e259 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b00e062055933388e37525df5766f3c14cd3538a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b01dc872c24db4db0c9179fc07e17f41060390de.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b01dc872c24db4db0c9179fc07e17f41060390de.hip new file mode 100644 index 00000000000..c176d406679 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b01dc872c24db4db0c9179fc07e17f41060390de.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b03ab68e33844f97aa58d463e00037bc11c50da0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b03ab68e33844f97aa58d463e00037bc11c50da0.hip new file mode 100644 index 00000000000..60cc872617d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b03ab68e33844f97aa58d463e00037bc11c50da0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b04f14f829eff73afaa57a875f74ebd1e6860979.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b04f14f829eff73afaa57a875f74ebd1e6860979.hip new file mode 100644 index 00000000000..12c685195aa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b04f14f829eff73afaa57a875f74ebd1e6860979.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
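Each of the autogenerated files in this directory follows the same contract: it instantiates one concrete kernel configuration and then explicitly specializes a small family of function templates for the matching trait type (here fmha_bwd_dq_dk_dv_, its _oneshot_ variant, and _get_name_), so that host-side dispatch code can pick an instantiation at run time from head dimension, element type, and the mask/bias/dropout flags encoded in the trait. The sketch below is a simplified, self-contained illustration of that explicit-specialization pattern; the names run_bwd_dq_dk_dv, bwd_kernel_name, trait_hdim64_fp16, stream_config, and bwd_args are hypothetical stand-ins, not the declarations from the shared CK headers that the generated files actually include.

// Minimal sketch (not the real CK headers): one primary template, one explicit
// specialization per generated translation unit, and a tiny runtime dispatcher.
#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for ck_tile::stream_config / fmha_bwd_args.
struct stream_config { int log_level = 0; };
struct bwd_args { int hdim; bool is_fp16; };

// Trait tags play the role of fmha_bwd_dq_dk_dv_traits_<...> instantiations.
struct trait_hdim64_fp16 {};
struct trait_hdim128_fp16 {};

// Primary templates: declared only. Each generated file provides exactly one
// explicit specialization, so the link step pulls in only what was built.
template <typename Trait>
float run_bwd_dq_dk_dv(const stream_config& s, bwd_args a);

template <typename Trait>
std::string bwd_kernel_name();

// What a single autogenerated file would contribute (two shown here).
template <>
float run_bwd_dq_dk_dv<trait_hdim64_fp16>(const stream_config&, bwd_args) {
    return 0.12f; // pretend elapsed milliseconds
}
template <>
std::string bwd_kernel_name<trait_hdim64_fp16>() { return "bwd_hd64_fp16"; }

template <>
float run_bwd_dq_dk_dv<trait_hdim128_fp16>(const stream_config&, bwd_args) {
    return 0.34f;
}
template <>
std::string bwd_kernel_name<trait_hdim128_fp16>() { return "bwd_hd128_fp16"; }

// Runtime dispatch over the compile-time trait space.
float dispatch_bwd(const stream_config& s, bwd_args a) {
    if (a.is_fp16 && a.hdim <= 64)  return run_bwd_dq_dk_dv<trait_hdim64_fp16>(s, a);
    if (a.is_fp16 && a.hdim <= 128) return run_bwd_dq_dk_dv<trait_hdim128_fp16>(s, a);
    throw std::runtime_error("no kernel instantiated for this configuration");
}

int main() {
    stream_config s;
    bwd_args a{128, true};
    std::cout << bwd_kernel_name<trait_hdim128_fp16>() << ": "
              << dispatch_bwd(s, a) << " ms\n";
}

One practical consequence of this layout is that every translation unit compiles exactly one heavyweight kernel instantiation, which keeps per-file compile times bounded and lets the build parallelize across the generated files (hence the CMake warning about building CK for many gfx archs at once).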
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0544a38dfdf4d81dc95894387845f48435e299a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0544a38dfdf4d81dc95894387845f48435e299a.hip new file mode 100644 index 00000000000..77a4df5b652 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0544a38dfdf4d81dc95894387845f48435e299a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0dd965d5d9080ed5c6a04b7eea9890f3a264f20.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0dd965d5d9080ed5c6a04b7eea9890f3a264f20.hip new file mode 100644 index 00000000000..916f9471868 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0dd965d5d9080ed5c6a04b7eea9890f3a264f20.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0f555b74ed36f1bef8f47880b3edc6760f27788.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0f555b74ed36f1bef8f47880b3edc6760f27788.hip new file mode 100644 index 00000000000..f7ff6b9b35f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0f555b74ed36f1bef8f47880b3edc6760f27788.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1766695dbb790bd614b83dc7569ad449404cc89.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1766695dbb790bd614b83dc7569ad449404cc89.hip new file mode 100644 index 00000000000..1f320e33ac5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1766695dbb790bd614b83dc7569ad449404cc89.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b18a615e66d7cd739ce35412811359a03cb23a8e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b18a615e66d7cd739ce35412811359a03cb23a8e.hip new file mode 100644 index 00000000000..1651f3cb839 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b18a615e66d7cd739ce35412811359a03cb23a8e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b192c55f002d8540d5f965cc4df0c2e33f4b9ff9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b192c55f002d8540d5f965cc4df0c2e33f4b9ff9.hip new file mode 100644 index 00000000000..dbe426a1830 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b192c55f002d8540d5f965cc4df0c2e33f4b9ff9.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+
+// auto generated by generate.py
+#include
+
+using fmha_dtype_0 = ck_tile::bf16_t;
+
+using fmha_bwd_dot_do_o_trait_0 =
+    ck_tile::TileFmhaBwdOGradDotOTraits;
+
+using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
+    typename FmhaBwdTypeConfig::ODataType,
+    typename FmhaBwdTypeConfig::OGradDataType,
+    typename FmhaBwdTypeConfig::DDataType,
+    /* BlockSize = */ 64,
+    64,
+    false,
+    fmha_bwd_dot_do_o_trait_0>;
+
+using fmha_bwd_dot_do_o_0 =
+    typename ck_tile::BlockFmhaBwdOGradDotO;
+
+using fmha_bwd_dot_do_o_kernel_0 =
+    ck_tile::FmhaBwdOGradDotOKernel;
+
+using dot_do_o_trait_0 =
+    fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>;
+
+#include
+
+template <>
+float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a)
+{
+    using k_ = fmha_bwd_dot_do_o_kernel_0;
+    if(s.log_level_ > 0)
+        std::cout << ", " << k_::GetName() << std::flush;
+    auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a);
+    constexpr dim3 blocks = k_::BlockSize();
+    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
+#if (defined(__gfx90a__) || defined(__gfx942__))
+    return ck_tile::launch_kernel(
+        s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs));
+#else
+    return 0.0;
+#endif
+}
+
+template <>
+void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a)
+{
+    using k_ = fmha_bwd_dot_do_o_kernel_0;
+    auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a);
+    constexpr dim3 blocks = k_::BlockSize();
+    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
+#if (defined(__gfx90a__) || defined(__gfx942__))
+    ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)(
+        ck_tile::stream_config{s.stream_id_});
+#endif
+}
+
+template <>
+std::string fmha_bwd_dot_do_o_get_name_()
+{
+    using k_ = fmha_bwd_dot_do_o_kernel_0;
+    return k_::GetName();
+}
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b19f05f6848403480ba41d37cdbf44ccca1b1f8d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b19f05f6848403480ba41d37cdbf44ccca1b1f8d.hip
new file mode 100644
index 00000000000..2cac2fc3b4e
--- /dev/null
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b19f05f6848403480ba41d37cdbf44ccca1b1f8d.hip
@@ -0,0 +1,144 @@
+// ==========================================
+// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
+// @generated
+// ==========================================
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
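The fmha_bwd_dot_do_o specializations above correspond to the usual FlashAttention backward preprocessing step suggested by the OGradDotO naming: for each query row i, compute D_i = sum_d dO[i][d] * O[i][d], which the dq/dk/dv kernels then reuse when forming the softmax gradient. The reference below is a minimal host-side sketch of that reduction under assumed row-major [seqlen, head_dim] layouts; it uses none of the ck_tile types, and the function name dot_do_o is illustrative only.

// Reference for the O.dO row-sum preprocessing computed by the
// fmha_bwd_dot_do_o kernels: D[i] = sum_d dO[i][d] * O[i][d].
#include <cstddef>
#include <iostream>
#include <vector>

// O and dO are [seqlen x head_dim], row-major; D is [seqlen].
std::vector<float> dot_do_o(const std::vector<float>& O,
                            const std::vector<float>& dO,
                            std::size_t seqlen, std::size_t hdim) {
    std::vector<float> D(seqlen, 0.0f);
    for (std::size_t i = 0; i < seqlen; ++i)
        for (std::size_t d = 0; d < hdim; ++d)
            D[i] += dO[i * hdim + d] * O[i * hdim + d];
    return D;
}

int main() {
    const std::size_t seqlen = 2, hdim = 3;
    std::vector<float> O  = {1, 2, 3, 4, 5, 6};
    std::vector<float> dO = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
    // Expected rows: 1*0.1 + 2*0.2 + 3*0.3 = 1.4 and 4*0.4 + 5*0.5 + 6*0.6 = 7.7
    for (float d : dot_do_o(O, dO, seqlen, hdim)) std::cout << d << '\n';
}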
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1ad101ce91348266d3885afdf2996a0fdb72135.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1ad101ce91348266d3885afdf2996a0fdb72135.hip new file mode 100644 index 00000000000..05d80da7524 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1ad101ce91348266d3885afdf2996a0fdb72135.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1c5d55d47d6038e9162d32ac968ff58c0942938.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1c5d55d47d6038e9162d32ac968ff58c0942938.hip new file mode 100644 index 00000000000..bd5640a71a9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1c5d55d47d6038e9162d32ac968ff58c0942938.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b20c6252863a73341b0010191fad4c834860f884.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b20c6252863a73341b0010191fad4c834860f884.hip new file mode 100644 index 00000000000..5aefb14dcc0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b20c6252863a73341b0010191fad4c834860f884.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b20e314642cf565e4f32bceffdb5c0e653ab627b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b20e314642cf565e4f32bceffdb5c0e653ab627b.hip new file mode 100644 index 00000000000..a5ccdcb12c6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b20e314642cf565e4f32bceffdb5c0e653ab627b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b24f91dec2029b25d0d96962528410df55a468ed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b24f91dec2029b25d0d96962528410df55a468ed.hip new file mode 100644 index 00000000000..dc448a1795d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b24f91dec2029b25d0d96962528410df55a468ed.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b285e2f1970b78e18002464eeda63798229bbc3a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b285e2f1970b78e18002464eeda63798229bbc3a.hip new file mode 100644 index 00000000000..6915e2071af --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b285e2f1970b78e18002464eeda63798229bbc3a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b298e213f927b518c693660110f08bdd94990ef0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b298e213f927b518c693660110f08bdd94990ef0.hip new file mode 100644 index 00000000000..76d83307c7f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b298e213f927b518c693660110f08bdd94990ef0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 
k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b2af5f5b5ee3ae964824a3e9c7bbeb5bb39c557c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b2af5f5b5ee3ae964824a3e9c7bbeb5bb39c557c.hip new file mode 100644 index 00000000000..708b4cee492 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b2af5f5b5ee3ae964824a3e9c7bbeb5bb39c557c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + 
+using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b2f91e937b427ecc932c0cb0c90b2c2378db0be6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b2f91e937b427ecc932c0cb0c90b2c2378db0be6.hip new file mode 100644 index 00000000000..e70029e1363 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b2f91e937b427ecc932c0cb0c90b2c2378db0be6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
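Note on the generated sources: every fmha_ck_autogen_*.hip file in this directory has the same shape. It fixes one tile/pipeline configuration, instantiates the corresponding ck_tile kernel type, and then provides an explicit specialization of the matching dispatch template, keyed on the trait_0 / dq_dk_dv_trait_0 / dot_do_o_trait_0 alias defined just above it, whose body launches the kernel only when compiled for gfx90a or gfx942 and otherwise returns 0.0. The sketch below is a minimal, self-contained stand-in for that pattern; the names (fmha_trait_sketch, run_fmha_, run_fmha_get_name_) are illustrative placeholders, not part of the ck_tile API.

// Minimal sketch of the trait-keyed specialization pattern used by these
// generated files (placeholder names; not the actual ck_tile interfaces).
#include <iostream>
#include <string>

// One compile-time configuration (here: head dim and causal masking) per trait.
template <int kHeadDim, bool kIsCausal>
struct fmha_trait_sketch {};

// Primary templates are declared only; each generated translation unit adds
// exactly one explicit specialization for the trait it was generated for.
template <typename Trait>
float run_fmha_(float softmax_scale);

template <typename Trait>
std::string run_fmha_get_name_();

// What a single autogen file contributes, in miniature.
template <>
float run_fmha_<fmha_trait_sketch<128, true>>(float softmax_scale)
{
#if (defined(__gfx90a__) || defined(__gfx942__))
    // The real files build kargs/grids here and call ck_tile::launch_kernel,
    // returning the measured kernel time.
    return softmax_scale * 0.0f;
#else
    // On unsupported target architectures the body compiles to a stub
    // that returns 0.0, mirroring the #else branches in the files above.
    return 0.0f;
#endif
}

template <>
std::string run_fmha_get_name_<fmha_trait_sketch<128, true>>()
{
    return "fmha_sketch_hdim128_causal";
}

int main()
{
    // A host-side dispatcher maps runtime parameters (dtype, head dim, mask,
    // bias, dropout) onto one of these specializations.
    std::cout << run_fmha_get_name_<fmha_trait_sketch<128, true>>() << ": "
              << run_fmha_<fmha_trait_sketch<128, true>>(0.125f) << "\n";
    return 0;
}

Keeping each configuration in its own translation unit lets the specializations be compiled independently, which is presumably why the generator emits one small file per trait rather than a single monolithic source.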
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3063d06723ac70c5f8802ab49c5c35e1debf56e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3063d06723ac70c5f8802ab49c5c35e1debf56e.hip new file mode 100644 index 00000000000..0d0e15072ae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3063d06723ac70c5f8802ab49c5c35e1debf56e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b31f56244076c501cb09b4b90975132cae4c4386.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b31f56244076c501cb09b4b90975132cae4c4386.hip new file mode 100644 index 00000000000..1148242e7bd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b31f56244076c501cb09b4b90975132cae4c4386.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3486244e0b7d6dbcaa1951e8b8883ce441c3f99.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3486244e0b7d6dbcaa1951e8b8883ce441c3f99.hip new file mode 100644 index 00000000000..4224bf6dcfa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3486244e0b7d6dbcaa1951e8b8883ce441c3f99.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b34c1ce348c3d9cdf6bbec9758de9d5fe94c43fc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b34c1ce348c3d9cdf6bbec9758de9d5fe94c43fc.hip new file mode 100644 index 00000000000..f9007c8d0a3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b34c1ce348c3d9cdf6bbec9758de9d5fe94c43fc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b38a1d3cffae01332a3a9d9472ff1b2c443e82af.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b38a1d3cffae01332a3a9d9472ff1b2c443e82af.hip new file mode 100644 index 00000000000..3af4858be13 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b38a1d3cffae01332a3a9d9472ff1b2c443e82af.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3a104733f678193068d8642d6560faa03897258.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3a104733f678193068d8642d6560faa03897258.hip new file mode 100644 index 00000000000..0412b0d6b1a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3a104733f678193068d8642d6560faa03897258.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3da22d3482738a8474ae15e8e5fca9020c4e195.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3da22d3482738a8474ae15e8e5fca9020c4e195.hip new file mode 100644 index 00000000000..9c8c668bd2b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3da22d3482738a8474ae15e8e5fca9020c4e195.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41735d250b5a16967281a5f07873b9cde3df4d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41735d250b5a16967281a5f07873b9cde3df4d6.hip new file mode 100644 index 00000000000..e5e19dc68d8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41735d250b5a16967281a5f07873b9cde3df4d6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41a30092e8138877c1f6c25656e0f8ae2c2444e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41a30092e8138877c1f6c25656e0f8ae2c2444e.hip new file mode 100644 index 00000000000..50a898f1adb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41a30092e8138877c1f6c25656e0f8ae2c2444e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41ea5293bc1c56efa2c4b5681d965aa6f2ce6c3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41ea5293bc1c56efa2c4b5681d965aa6f2ce6c3.hip new file mode 100644 index 00000000000..535f441e9d4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41ea5293bc1c56efa2c4b5681d965aa6f2ce6c3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4588379eaa268d79fe8f8e4457b009f204a5fb7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4588379eaa268d79fe8f8e4457b009f204a5fb7.hip new file mode 100644 index 00000000000..568fb2a61a6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4588379eaa268d79fe8f8e4457b009f204a5fb7.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b493c99888d82cd2852bfb101f99a2e6a27665b8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b493c99888d82cd2852bfb101f99a2e6a27665b8.hip new file mode 100644 index 00000000000..b1a20313826 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b493c99888d82cd2852bfb101f99a2e6a27665b8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4a5715b550f67b8870ba66e1e6282a26cc1dbf3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4a5715b550f67b8870ba66e1e6282a26cc1dbf3.hip new file mode 100644 index 00000000000..46e3bb80c18 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4a5715b550f67b8870ba66e1e6282a26cc1dbf3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4b037a2e262d11d3ed7d9feeb41b9e05427a739.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4b037a2e262d11d3ed7d9feeb41b9e05427a739.hip new file mode 100644 index 00000000000..82f28e19893 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4b037a2e262d11d3ed7d9feeb41b9e05427a739.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4bd2d206ceb237ed2c51f58abb5cbf96e39d07b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4bd2d206ceb237ed2c51f58abb5cbf96e39d07b.hip new file mode 100644 index 00000000000..e2273c31b33 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4bd2d206ceb237ed2c51f58abb5cbf96e39d07b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4ec377c44ac18527ca6a01bc3b146706a6e1e09.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4ec377c44ac18527ca6a01bc3b146706a6e1e09.hip new file mode 100644 index 00000000000..44950e6110f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4ec377c44ac18527ca6a01bc3b146706a6e1e09.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4f12f10d7b968e0d8e7c23f36d3a360de74a905.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4f12f10d7b968e0d8e7c23f36d3a360de74a905.hip new file mode 100644 index 00000000000..923f41381c9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4f12f10d7b968e0d8e7c23f36d3a360de74a905.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b50e6df20a2426abd3d2ff2262a37c009196024c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b50e6df20a2426abd3d2ff2262a37c009196024c.hip new file mode 100644 index 00000000000..9a670ef26e1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b50e6df20a2426abd3d2ff2262a37c009196024c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
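Each of the generated translation units above follows the same pattern: it instantiates one backward kernel configuration (block tile, warp layout, pipeline, epilogues), then provides explicit specializations keyed on its trait alias — a timed launcher (fmha_bwd_dq_dk_dv_), a fire-and-forget launcher (fmha_bwd_dq_dk_dv_oneshot_), and a name query (fmha_bwd_dq_dk_dv_get_name_). The sketch below is illustrative only and not part of the generated sources; it assumes the same headers as the generated file and reuses its dq_dk_dv_trait_0 alias, and run_one_bwd_config is a hypothetical helper name.

// Illustrative sketch (assumes the generated file's includes and trait alias).
float run_one_bwd_config(const ck_tile::stream_config& s, fmha_bwd_args a)
{
    if (s.log_level_ > 0)
        // Query the kernel name string exposed by the generated specialization.
        std::cout << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>() << std::endl;
    // Launch the kernel; returns the measured time, or 0.0 when the file was
    // compiled for an architecture other than gfx90a/gfx942 (per the guard
    // in the generated body).
    return fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(s, a);
}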
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b513834918d5ea789e2db21abece7c2d3532a7e7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b513834918d5ea789e2db21abece7c2d3532a7e7.hip new file mode 100644 index 00000000000..32b5404e47e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b513834918d5ea789e2db21abece7c2d3532a7e7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5248f443a12d96815c04409a00102923c717023.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5248f443a12d96815c04409a00102923c717023.hip new file mode 100644 index 00000000000..3998da16612 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5248f443a12d96815c04409a00102923c717023.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5371415448fffffd58bf014dac9f4876153657b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5371415448fffffd58bf014dac9f4876153657b.hip new file mode 100644 index 00000000000..957890d6bce --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5371415448fffffd58bf014dac9f4876153657b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5ac596c636df55e81293228cbc53dcbb3024e5a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5ac596c636df55e81293228cbc53dcbb3024e5a.hip new file mode 100644 index 00000000000..0d423090af9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5ac596c636df55e81293228cbc53dcbb3024e5a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5ba2e73df35f6e0f7317303823fde92a42b1a35.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5ba2e73df35f6e0f7317303823fde92a42b1a35.hip new file mode 100644 index 00000000000..c027189593b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5ba2e73df35f6e0f7317303823fde92a42b1a35.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
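In these files the timed entry point wraps the launch in ck_tile::launch_kernel so it can report an elapsed time, while the oneshot variant constructs the kernel and invokes it directly on the stream without timing. A minimal sketch of using the oneshot form follows; it is not part of the generated code, assumes the same includes and trait alias as the file above, and the helper name enqueue_bwd_once is hypothetical.

// Illustrative sketch: enqueue one backward pass without timing overhead.
void enqueue_bwd_once(const ck_tile::stream_config& s, fmha_bwd_args a)
{
    // The generated specialization builds the kernel arguments and launch grid
    // internally and enqueues on s.stream_id_; nothing is returned and no
    // synchronization is performed here.
    fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(s, a);
}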
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5bccc85f74f54a2ceb17fe3040b04fe306c53f9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5bccc85f74f54a2ceb17fe3040b04fe306c53f9.hip new file mode 100644 index 00000000000..a8c37b18586 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5bccc85f74f54a2ceb17fe3040b04fe306c53f9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5c3131fb8e5a25bd4a14bc9075eb6fa01b61d02.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5c3131fb8e5a25bd4a14bc9075eb6fa01b61d02.hip new file mode 100644 index 00000000000..ea9eb74433c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5c3131fb8e5a25bd4a14bc9075eb6fa01b61d02.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else 
+ return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5c7fca1f76a31b0390e92d90d569fab94d4f783.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5c7fca1f76a31b0390e92d90d569fab94d4f783.hip new file mode 100644 index 00000000000..0c2d7bb88b2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5c7fca1f76a31b0390e92d90d569fab94d4f783.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5db3d5b1d8af89381fc4b8073f84c5fa25fdef5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5db3d5b1d8af89381fc4b8073f84c5fa25fdef5.hip new file mode 100644 index 00000000000..1eb209d714a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5db3d5b1d8af89381fc4b8073f84c5fa25fdef5.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b60a4e87a7aabfe3c1ce02b408522f3ec862e3d7.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b60a4e87a7aabfe3c1ce02b408522f3ec862e3d7.hip new file mode 100644 index 00000000000..19c0a9cc05b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b60a4e87a7aabfe3c1ce02b408522f3ec862e3d7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> 
+void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b6b17ae67adee9e56a022cd2a5514fb9c4e99920.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b6b17ae67adee9e56a022cd2a5514fb9c4e99920.hip new file mode 100644 index 00000000000..1671e0dd893 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b6b17ae67adee9e56a022cd2a5514fb9c4e99920.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b72a804bb3c99830653d41ac0bd49943c801b89a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b72a804bb3c99830653d41ac0bd49943c801b89a.hip new file mode 100644 index 00000000000..ad7cb889666 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b72a804bb3c99830653d41ac0bd49943c801b89a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b737410b404a51043fc3bd503c0b107c297e4c9f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b737410b404a51043fc3bd503c0b107c297e4c9f.hip new file mode 100644 index 00000000000..10f202afe22 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b737410b404a51043fc3bd503c0b107c297e4c9f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b75843bb13058ffe29251e053800c509c7590544.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b75843bb13058ffe29251e053800c509c7590544.hip new file mode 100644 index 00000000000..6d508d109ce --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b75843bb13058ffe29251e053800c509c7590544.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b774450ebadaacf23e944aaf8ca90eada01e8a5a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b774450ebadaacf23e944aaf8ca90eada01e8a5a.hip new file mode 100644 index 00000000000..c36d88e3124 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b774450ebadaacf23e944aaf8ca90eada01e8a5a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b779cc0b0380e1e6a2b51fc6216fdd72215b882b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b779cc0b0380e1e6a2b51fc6216fdd72215b882b.hip new file mode 100644 index 00000000000..e27e532203d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b779cc0b0380e1e6a2b51fc6216fdd72215b882b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b7a03ab0b7887cc7ed0cb40e56360a8d36c0bb8e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b7a03ab0b7887cc7ed0cb40e56360a8d36c0bb8e.hip new file mode 100644 index 00000000000..6488aeb1bc3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b7a03ab0b7887cc7ed0cb40e56360a8d36c0bb8e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b80d0828ba6d24ea3c1a97bd9835ee937b4b32fb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b80d0828ba6d24ea3c1a97bd9835ee937b4b32fb.hip new file mode 100644 index 00000000000..62623e588e4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b80d0828ba6d24ea3c1a97bd9835ee937b4b32fb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b872f9e6ebe330cc1818ea82b53acec79a2f672c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b872f9e6ebe330cc1818ea82b53acec79a2f672c.hip new file mode 100644 index 00000000000..ea3a4fab391 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b872f9e6ebe330cc1818ea82b53acec79a2f672c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b8fbc6f6e9c515edce3c7a438b3bc308b30d3857.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b8fbc6f6e9c515edce3c7a438b3bc308b30d3857.hip new file mode 100644 index 00000000000..d826af15042 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b8fbc6f6e9c515edce3c7a438b3bc308b30d3857.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9385db12001110c42eff6aabad935a69ad3afe2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9385db12001110c42eff6aabad935a69ad3afe2.hip new file mode 100644 index 00000000000..288c0b2167d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9385db12001110c42eff6aabad935a69ad3afe2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9559dd36a0a4f5e068a722e285f485137bd5ef0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9559dd36a0a4f5e068a722e285f485137bd5ef0.hip new file mode 100644 index 00000000000..26ec43c91b0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9559dd36a0a4f5e068a722e285f485137bd5ef0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9627f9c8d0088df0364a64643f2b5dcd951f2bb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9627f9c8d0088df0364a64643f2b5dcd951f2bb.hip new file mode 100644 index 00000000000..0e201d26453 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9627f9c8d0088df0364a64643f2b5dcd951f2bb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else 
+ return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9a742ceeb6736a2c8f9439d0b05e10d3e0c5c6f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9a742ceeb6736a2c8f9439d0b05e10d3e0c5c6f.hip new file mode 100644 index 00000000000..ed60152f8ac --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9a742ceeb6736a2c8f9439d0b05e10d3e0c5c6f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) 
+ return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9baf70220079e6d4e87eb01a7259923d8a01e29.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9baf70220079e6d4e87eb01a7259923d8a01e29.hip new file mode 100644 index 00000000000..55e65309e5e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9baf70220079e6d4e87eb01a7259923d8a01e29.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9d00ab8373747a5c6b9d2f8dd50ceb14db4163c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9d00ab8373747a5c6b9d2f8dd50ceb14db4163c.hip new file mode 100644 index 00000000000..8f6303a1c6c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9d00ab8373747a5c6b9d2f8dd50ceb14db4163c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9ed0a64deb55616646ea98b21a891c971cd98ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9ed0a64deb55616646ea98b21a891c971cd98ad.hip new file mode 100644 index 00000000000..e313cb363ef --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9ed0a64deb55616646ea98b21a891c971cd98ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ba145535e53899fe127987aa854f81234a9c51c4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ba145535e53899fe127987aa854f81234a9c51c4.hip new file mode 100644 index 00000000000..0a750314063 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ba145535e53899fe127987aa854f81234a9c51c4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ba8b09f0aaa40a7c9ad5f0458b460d3e328f3c74.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ba8b09f0aaa40a7c9ad5f0458b460d3e328f3c74.hip new file mode 100644 index 00000000000..32a61c6a29c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ba8b09f0aaa40a7c9ad5f0458b460d3e328f3c74.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bafbef3f13d429ec3e9f4672218998d5669d79f2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bafbef3f13d429ec3e9f4672218998d5669d79f2.hip new file mode 100644 index 00000000000..e102afa7c30 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bafbef3f13d429ec3e9f4672218998d5669d79f2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb111b7acc269f8d5e70915d3efde4c425aa5f5c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb111b7acc269f8d5e70915d3efde4c425aa5f5c.hip new file mode 100644 index 00000000000..697f8d6f393 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb111b7acc269f8d5e70915d3efde4c425aa5f5c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb28a4e95723e3df380f98b5ac107c4df353850b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb28a4e95723e3df380f98b5ac107c4df353850b.hip new file mode 100644 index 00000000000..0c0d4e1400a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb28a4e95723e3df380f98b5ac107c4df353850b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb35c86443cc9ea38c06ebc0656306483c95ef67.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb35c86443cc9ea38c06ebc0656306483c95ef67.hip new file mode 100644 index 00000000000..b9c554b89b1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb35c86443cc9ea38c06ebc0656306483c95ef67.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
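// ---------------------------------------------------------------------------
// The backward files above each specialize three hooks for one trait: a
// launcher that returns the float from ck_tile::launch_kernel, a one-shot
// launcher that fires the kernel without returning a measurement, and a name
// query used for logging. The stand-alone analogue below mirrors that trio
// with hypothetical names (fake_kernel, run_timed, run_once, kernel_name);
// it is a sketch of the structure only, not the generated API.
#include <chrono>
#include <iostream>
#include <string>

struct fake_kernel {                          // stands in for the kernel type
    static std::string GetName() { return "fmha_bwd_sketch_kernel"; }
    void operator()() const { /* pretend work */ }
};

template <typename K>
float run_timed()                             // ~ fmha_bwd_dq_dk_dv_<trait>
{
    auto t0 = std::chrono::steady_clock::now();
    K{}();                                    // launch once and time it
    auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<float, std::milli>(t1 - t0).count();
}

template <typename K>
void run_once()                               // ~ fmha_bwd_dq_dk_dv_oneshot_
{
    K{}();                                    // launch without measuring
}

template <typename K>
std::string kernel_name()                     // ~ fmha_bwd_dq_dk_dv_get_name_
{
    return K::GetName();
}

int main()
{
    run_once<fake_kernel>();
    std::cout << kernel_name<fake_kernel>() << " took "
              << run_timed<fake_kernel>() << " ms\n";
}
// ---------------------------------------------------------------------------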
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bba10ecb79ede07324e1198a71a95ff26e9eb235.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bba10ecb79ede07324e1198a71a95ff26e9eb235.hip new file mode 100644 index 00000000000..ea4a38d7df1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bba10ecb79ede07324e1198a71a95ff26e9eb235.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bbe23201fbebed25781f249e5c77c31e0e7f9ddb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bbe23201fbebed25781f249e5c77c31e0e7f9ddb.hip new file mode 100644 index 00000000000..5f464dfd0ca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bbe23201fbebed25781f249e5c77c31e0e7f9ddb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bbfd025488e52b97c04995c4c5faff371b77e4d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bbfd025488e52b97c04995c4c5faff371b77e4d6.hip new file mode 100644 index 00000000000..b2ea0970aa6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bbfd025488e52b97c04995c4c5faff371b77e4d6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc1ae1dddb8cc5d78196da6b26ebe66c1ce7e567.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc1ae1dddb8cc5d78196da6b26ebe66c1ce7e567.hip new file mode 100644 index 00000000000..96a10e08380 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc1ae1dddb8cc5d78196da6b26ebe66c1ce7e567.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc238fd2095b26a167b41cdec8280182330b7b25.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc238fd2095b26a167b41cdec8280182330b7b25.hip new file mode 100644 index 00000000000..8db90faeefc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc238fd2095b26a167b41cdec8280182330b7b25.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc4425e30a0b17e8b31726817e8d3177b5c51934.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc4425e30a0b17e8b31726817e8d3177b5c51934.hip new file mode 100644 index 00000000000..b0da0f98357 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc4425e30a0b17e8b31726817e8d3177b5c51934.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc4e0f0496a34d2fb43c80ce0162ad4183f29064.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc4e0f0496a34d2fb43c80ce0162ad4183f29064.hip new file mode 100644 index 00000000000..cde5025b204 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc4e0f0496a34d2fb43c80ce0162ad4183f29064.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc6ce17223d8d83a64b8c96ac88223e4441a4692.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc6ce17223d8d83a64b8c96ac88223e4441a4692.hip new file mode 100644 index 00000000000..f52451423cd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc6ce17223d8d83a64b8c96ac88223e4441a4692.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc744db85d4237ee9640f1658e0caab7648e3bb6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc744db85d4237ee9640f1658e0caab7648e3bb6.hip new file mode 100644 index 00000000000..c75edd4b648 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc744db85d4237ee9640f1658e0caab7648e3bb6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc79e255d25744725e2a9db9f90d5cc2b8a0e0c1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc79e255d25744725e2a9db9f90d5cc2b8a0e0c1.hip new file mode 100644 index 00000000000..c96d54733ac --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc79e255d25744725e2a9db9f90d5cc2b8a0e0c1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc897852a4ca992961843144f4ec4f8b86dd5e9d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc897852a4ca992961843144f4ec4f8b86dd5e9d.hip new file mode 100644 index 00000000000..a10f2dcd0b0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc897852a4ca992961843144f4ec4f8b86dd5e9d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcb6f0730fd09b4c6c60913425927dfdb8f83d82.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcb6f0730fd09b4c6c60913425927dfdb8f83d82.hip new file mode 100644 index 00000000000..ffb2ee41786 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcb6f0730fd09b4c6c60913425927dfdb8f83d82.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcd7ccdceb7baf3b986f2a0248827822a5f72e47.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcd7ccdceb7baf3b986f2a0248827822a5f72e47.hip new file mode 100644 index 00000000000..61f98c5dbad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcd7ccdceb7baf3b986f2a0248827822a5f72e47.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcf8836c8cf932cc2748e313885003f0e11a887f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcf8836c8cf932cc2748e313885003f0e11a887f.hip new file mode 100644 index 00000000000..acdbebfd5c1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcf8836c8cf932cc2748e313885003f0e11a887f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd064e302ff5b983dbdb4ccf51383fb29ddff44f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd064e302ff5b983dbdb4ccf51383fb29ddff44f.hip new file mode 100644 index 00000000000..48d33029a3b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd064e302ff5b983dbdb4ccf51383fb29ddff44f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd28203f47b6a48e9b66302cf8312f3796ca500c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd28203f47b6a48e9b66302cf8312f3796ca500c.hip new file mode 100644 index 00000000000..bd8ca3a01d5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd28203f47b6a48e9b66302cf8312f3796ca500c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd37f4f7914805a97d5073f1ebf8a8b8c2648d31.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd37f4f7914805a97d5073f1ebf8a8b8c2648d31.hip new file mode 100644 index 00000000000..ea4b4b24aab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd37f4f7914805a97d5073f1ebf8a8b8c2648d31.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd3daa5f99b4522d932334924347353ce2854821.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd3daa5f99b4522d932334924347353ce2854821.hip new file mode 100644 index 00000000000..d7575f04ec4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd3daa5f99b4522d932334924347353ce2854821.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd6aa39d0ae3c87d011610cdb5e2e317f337c454.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd6aa39d0ae3c87d011610cdb5e2e317f337c454.hip new file mode 100644 index 00000000000..6a2de30bb01 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd6aa39d0ae3c87d011610cdb5e2e317f337c454.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd80a1774d8b7d8bee4e8663392b97cda11dcbf5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd80a1774d8b7d8bee4e8663392b97cda11dcbf5.hip new file mode 100644 index 00000000000..aeb274f09f1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd80a1774d8b7d8bee4e8663392b97cda11dcbf5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd8bf7c572c1984ca3061062cf3c31d993f6762d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd8bf7c572c1984ca3061062cf3c31d993f6762d.hip new file mode 100644 index 00000000000..f32a6421d83 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd8bf7c572c1984ca3061062cf3c31d993f6762d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd9c47f3305e47db6ab6bc627fb3d80269633074.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd9c47f3305e47db6ab6bc627fb3d80269633074.hip new file mode 100644 index 00000000000..f941e72f617 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd9c47f3305e47db6ab6bc627fb3d80269633074.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bdab172627718278a71a93e3737ef08ad9259a4f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bdab172627718278a71a93e3737ef08ad9259a4f.hip new file mode 100644 index 00000000000..64bacecc2cd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bdab172627718278a71a93e3737ef08ad9259a4f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bde24a8dbe6add6f2dd2beb48b1280f3a84a9b2a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bde24a8dbe6add6f2dd2beb48b1280f3a84a9b2a.hip new file mode 100644 index 00000000000..1dcb86091fb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bde24a8dbe6add6f2dd2beb48b1280f3a84a9b2a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be1e1533fc37b41838bd37edc2b6d2f2e76ae1c6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be1e1533fc37b41838bd37edc2b6d2f2e76ae1c6.hip new file mode 100644 index 00000000000..59c5142497f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be1e1533fc37b41838bd37edc2b6d2f2e76ae1c6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be4dd90ccb2f258029d0156cf23f940b694cf08d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be4dd90ccb2f258029d0156cf23f940b694cf08d.hip new file mode 100644 index 00000000000..ad15084e4cc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be4dd90ccb2f258029d0156cf23f940b694cf08d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be8ec1163a01b9cd9a802d8b44669e8770c20234.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be8ec1163a01b9cd9a802d8b44669e8770c20234.hip new file mode 100644 index 00000000000..58864650316 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be8ec1163a01b9cd9a802d8b44669e8770c20234.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_beae876d6da465687f162136231f15767cc7bb14.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_beae876d6da465687f162136231f15767cc7bb14.hip new file mode 100644 index 00000000000..2780b0d568b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_beae876d6da465687f162136231f15767cc7bb14.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else 
+ return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_beb9afccc15de7dfcb2e7d898abc0d61201de73e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_beb9afccc15de7dfcb2e7d898abc0d61201de73e.hip new file mode 100644 index 00000000000..a1619cab297 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_beb9afccc15de7dfcb2e7d898abc0d61201de73e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bec30e7107c5dce3fe6aa87d83ed96da75478da0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bec30e7107c5dce3fe6aa87d83ed96da75478da0.hip new file mode 100644 index 00000000000..0957d6a9b1f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bec30e7107c5dce3fe6aa87d83ed96da75478da0.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + true, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bec9e4c0317e8d351f60258ed6611fbf365c4024.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bec9e4c0317e8d351f60258ed6611fbf365c4024.hip new file mode 100644 index 00000000000..46512ea2b5a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bec9e4c0317e8d351f60258ed6611fbf365c4024.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void 
fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_becc2a4d7ac045365300bf8bd45fc6d3e1e1c8b1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_becc2a4d7ac045365300bf8bd45fc6d3e1e1c8b1.hip new file mode 100644 index 00000000000..389830690d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_becc2a4d7ac045365300bf8bd45fc6d3e1e1c8b1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
+ fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bed5a8c5cf683f6dfaefad72c2e2f5c2f2b2732f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bed5a8c5cf683f6dfaefad72c2e2f5c2f2b2732f.hip new file mode 100644 index 00000000000..d33106b0be7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bed5a8c5cf683f6dfaefad72c2e2f5c2f2b2732f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bef3bd014a918feddadc98eed92a7734f9bcd890.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bef3bd014a918feddadc98eed92a7734f9bcd890.hip new file mode 100644 index 00000000000..3e2a3ee96e8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bef3bd014a918feddadc98eed92a7734f9bcd890.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bf9cdf86a7944cd690b0fcbbaec235863acd10bb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bf9cdf86a7944cd690b0fcbbaec235863acd10bb.hip new file mode 100644 index 00000000000..356a97e25a7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bf9cdf86a7944cd690b0fcbbaec235863acd10bb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0338fbc05f86270ded7df2bd3e2758a03961b62.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0338fbc05f86270ded7df2bd3e2758a03961b62.hip new file mode 100644 index 00000000000..748c01991b5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0338fbc05f86270ded7df2bd3e2758a03961b62.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0342686e4efd26413c6719782ed13603479c4e0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0342686e4efd26413c6719782ed13603479c4e0.hip new file mode 100644 index 00000000000..54fc606a0bb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0342686e4efd26413c6719782ed13603479c4e0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c063318cb851ccaa923be12d34c84d839bc64bb8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c063318cb851ccaa923be12d34c84d839bc64bb8.hip new file mode 100644 index 00000000000..f8b56fbd1b0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c063318cb851ccaa923be12d34c84d839bc64bb8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c08095341ca7e3a1debeb780c1878e351692bee2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c08095341ca7e3a1debeb780c1878e351692bee2.hip new file mode 100644 index 00000000000..3374fc97d09 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c08095341ca7e3a1debeb780c1878e351692bee2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0a3c4ac0a50bb9b7ad764929dbee98c856b1210.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0a3c4ac0a50bb9b7ad764929dbee98c856b1210.hip new file mode 100644 index 00000000000..aa55230f948 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0a3c4ac0a50bb9b7ad764929dbee98c856b1210.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0f76aff077c28f8afd7b22f284cf2894e08a043.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0f76aff077c28f8afd7b22f284cf2894e08a043.hip new file mode 100644 index 00000000000..7050da91098 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0f76aff077c28f8afd7b22f284cf2894e08a043.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c112c01d201c366bdd7acccf2e1b18b00f671153.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c112c01d201c366bdd7acccf2e1b18b00f671153.hip new file mode 100644 index 00000000000..8d8c07d01fe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c112c01d201c366bdd7acccf2e1b18b00f671153.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void 
fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c11d68fe766fc753c657362673704005b538660b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c11d68fe766fc753c657362673704005b538660b.hip new file mode 100644 index 00000000000..e54ab7e0dd0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c11d68fe766fc753c657362673704005b538660b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, 
+ fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c137c03bf161b2ec6a9a046fa49d7bbf80ae47b8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c137c03bf161b2ec6a9a046fa49d7bbf80ae47b8.hip new file mode 100644 index 00000000000..c286c16ba83 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c137c03bf161b2ec6a9a046fa49d7bbf80ae47b8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c197d1f050f42d82e6851fa286db6f81ba197f40.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c197d1f050f42d82e6851fa286db6f81ba197f40.hip new file mode 100644 index 00000000000..621c782e0e9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c197d1f050f42d82e6851fa286db6f81ba197f40.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1b76bc7a17f573c0d52c07ae9ff4302662ae61f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1b76bc7a17f573c0d52c07ae9ff4302662ae61f.hip new file mode 100644 index 00000000000..9d1e6d6551c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1b76bc7a17f573c0d52c07ae9ff4302662ae61f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1b94e19d762ddc33cc4e94c6675d93cbde21e3d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1b94e19d762ddc33cc4e94c6675d93cbde21e3d.hip new file mode 100644 index 00000000000..a83be7a1699 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1b94e19d762ddc33cc4e94c6675d93cbde21e3d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1f40c3421b9ad8cf43940530ec50bcf620058f2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1f40c3421b9ad8cf43940530ec50bcf620058f2.hip new file mode 100644 index 00000000000..1da958d9209 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1f40c3421b9ad8cf43940530ec50bcf620058f2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1f721a330b2d0fac13b22061616d7b10c0f91e9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1f721a330b2d0fac13b22061616d7b10c0f91e9.hip new file mode 100644 index 00000000000..5917659c09c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1f721a330b2d0fac13b22061616d7b10c0f91e9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c250ea59ab6e1ee39cce15cbd3f181047cdee31a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c250ea59ab6e1ee39cce15cbd3f181047cdee31a.hip new file mode 100644 index 00000000000..21c71b806ab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c250ea59ab6e1ee39cce15cbd3f181047cdee31a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2541b6b5cf27de3f45f60671d36602f07ce1783.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2541b6b5cf27de3f45f60671d36602f07ce1783.hip new file mode 100644 index 00000000000..73c4cb917df --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2541b6b5cf27de3f45f60671d36602f07ce1783.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c27b3026f1dc3056dee3a3e64bf31c45683607c9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c27b3026f1dc3056dee3a3e64bf31c45683607c9.hip new file mode 100644 index 00000000000..d360841e5b2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c27b3026f1dc3056dee3a3e64bf31c45683607c9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c28de8f96c8315877031a2d56261e95fee6aef44.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c28de8f96c8315877031a2d56261e95fee6aef44.hip new file mode 100644 index 00000000000..e7ad1d9e903 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c28de8f96c8315877031a2d56261e95fee6aef44.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c29110dd501853e87ebc122dd1971b0bb1bcd92f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c29110dd501853e87ebc122dd1971b0bb1bcd92f.hip new file mode 100644 index 00000000000..81aab7917a4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c29110dd501853e87ebc122dd1971b0bb1bcd92f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2940fd05efd52bdf8a3f9aa4b78bde9b5809b34.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2940fd05efd52bdf8a3f9aa4b78bde9b5809b34.hip new file mode 100644 index 00000000000..e645a9cf03f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2940fd05efd52bdf8a3f9aa4b78bde9b5809b34.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2a2856bf9a81544a30d535a13554e3a8107c476.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2a2856bf9a81544a30d535a13554e3a8107c476.hip new file mode 100644 index 00000000000..61bcb0c3a4d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2a2856bf9a81544a30d535a13554e3a8107c476.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2b719893a4d8a1e71857966d399f06c0a41749c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2b719893a4d8a1e71857966d399f06c0a41749c.hip new file mode 100644 index 00000000000..f826edcf7a4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2b719893a4d8a1e71857966d399f06c0a41749c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2f04447e6a94c94a2315454e71d7d607a9fd0f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2f04447e6a94c94a2315454e71d7d607a9fd0f8.hip new file mode 100644 index 00000000000..3bf5eafe74b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2f04447e6a94c94a2315454e71d7d607a9fd0f8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2fcced07cc194a8050bc7b2f791453b3f5b2064.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2fcced07cc194a8050bc7b2f791453b3f5b2064.hip new file mode 100644 index 00000000000..ab19c5ef64b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2fcced07cc194a8050bc7b2f791453b3f5b2064.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c323a4d1f24d59bddd20ed2f2fb6446627b0ae8b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c323a4d1f24d59bddd20ed2f2fb6446627b0ae8b.hip new file mode 100644 index 00000000000..dc7700e44f3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c323a4d1f24d59bddd20ed2f2fb6446627b0ae8b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c355189ade9b1a8269230232db754a3881b53168.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c355189ade9b1a8269230232db754a3881b53168.hip new file mode 100644 index 00000000000..61d12216a26 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c355189ade9b1a8269230232db754a3881b53168.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c35ea54eb6cd0f3756c462c66d9be956279b46ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c35ea54eb6cd0f3756c462c66d9be956279b46ad.hip new file mode 100644 index 00000000000..791e9ed6ef1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c35ea54eb6cd0f3756c462c66d9be956279b46ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( 
+ ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c363ee1b087f6b504a3dd3972b96e77db02b0582.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c363ee1b087f6b504a3dd3972b96e77db02b0582.hip new file mode 100644 index 00000000000..bba64793b7b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c363ee1b087f6b504a3dd3972b96e77db02b0582.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c3cfaf0d53869c373f6d0ec821b008dbb819141a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c3cfaf0d53869c373f6d0ec821b008dbb819141a.hip new file mode 100644 index 00000000000..53c8d770ad7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c3cfaf0d53869c373f6d0ec821b008dbb819141a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c3d0eaf9399c863d672e8c08d123739bab837d4b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c3d0eaf9399c863d672e8c08d123739bab837d4b.hip new file mode 100644 index 00000000000..cf4e6b089f6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c3d0eaf9399c863d672e8c08d123739bab837d4b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4015f0d0a7a5173810f6f17c00065e03fc61a89.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4015f0d0a7a5173810f6f17c00065e03fc61a89.hip new file mode 100644 index 00000000000..d20fa308e52 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4015f0d0a7a5173810f6f17c00065e03fc61a89.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c402e84359b2037a29efd1d6ce7213ba7605ab25.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c402e84359b2037a29efd1d6ce7213ba7605ab25.hip new file mode 100644 index 00000000000..7e73f68899b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c402e84359b2037a29efd1d6ce7213ba7605ab25.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c41b6eda4f250da059fe0c428428219ff5a250ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c41b6eda4f250da059fe0c428428219ff5a250ef.hip new file mode 100644 index 00000000000..9c7295dd911 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c41b6eda4f250da059fe0c428428219ff5a250ef.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c42ab428503e8f8bfa78c8cb8d9afad9f5185118.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c42ab428503e8f8bfa78c8cb8d9afad9f5185118.hip new file mode 100644 index 00000000000..836db9f6e86 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c42ab428503e8f8bfa78c8cb8d9afad9f5185118.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4376ac8d82db1bc25fa273a80dfbf8b71ee5e2b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4376ac8d82db1bc25fa273a80dfbf8b71ee5e2b.hip new file mode 100644 index 00000000000..ea220cf459a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4376ac8d82db1bc25fa273a80dfbf8b71ee5e2b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c45a5e40f6a66bc5292a56e0097c69fe37cedfb3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c45a5e40f6a66bc5292a56e0097c69fe37cedfb3.hip new file mode 100644 index 00000000000..4f64a8af7d8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c45a5e40f6a66bc5292a56e0097c69fe37cedfb3.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c487a1a9933239270f44b1e08e1cf5323521c089.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c487a1a9933239270f44b1e08e1cf5323521c089.hip 
new file mode 100644 index 00000000000..13d2416038c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c487a1a9933239270f44b1e08e1cf5323521c089.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = 
fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4997f79435cf64add10506acb97d0647cfbb3d4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4997f79435cf64add10506acb97d0647cfbb3d4.hip new file mode 100644 index 00000000000..fa1724ec38c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4997f79435cf64add10506acb97d0647cfbb3d4.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4b34d3cb673447773f6da23e9cf52b98e99f718.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4b34d3cb673447773f6da23e9cf52b98e99f718.hip new file mode 100644 index 00000000000..6af3d4ddbd4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4b34d3cb673447773f6da23e9cf52b98e99f718.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void 
fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4c3425fe683d35dc3335db77d183ad1620b7a92.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4c3425fe683d35dc3335db77d183ad1620b7a92.hip new file mode 100644 index 00000000000..36817d57ff4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4c3425fe683d35dc3335db77d183ad1620b7a92.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4c6c405cefe204824e8fad1b3dd34bba87e796a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4c6c405cefe204824e8fad1b3dd34bba87e796a.hip new file mode 100644 index 00000000000..0ad57e26bb9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4c6c405cefe204824e8fad1b3dd34bba87e796a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4de1bc135191f3c2aff740f4c6bb7e98da42f84.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4de1bc135191f3c2aff740f4c6bb7e98da42f84.hip new file mode 100644 index 00000000000..020af5402f6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4de1bc135191f3c2aff740f4c6bb7e98da42f84.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4dec99707511cebd9188d216ee0a148d729b470.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4dec99707511cebd9188d216ee0a148d729b470.hip new file mode 100644 index 00000000000..953d23ed4f0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4dec99707511cebd9188d216ee0a148d729b470.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c538dc4f65d02776875627cbd20a9c794d70b043.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c538dc4f65d02776875627cbd20a9c794d70b043.hip new file mode 100644 index 00000000000..1ca6dbbccfd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c538dc4f65d02776875627cbd20a9c794d70b043.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c53e295b68e807774ed31bb914e4bc59312a77d7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c53e295b68e807774ed31bb914e4bc59312a77d7.hip new file mode 100644 index 00000000000..74c24521183 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c53e295b68e807774ed31bb914e4bc59312a77d7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c56aa150611b0d4800470c1493dc907082a5c23f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c56aa150611b0d4800470c1493dc907082a5c23f.hip new file mode 100644 index 00000000000..88a7b6c0c0b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c56aa150611b0d4800470c1493dc907082a5c23f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c581974c8b6f43f60d0af29c350d850b55c03121.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c581974c8b6f43f60d0af29c350d850b55c03121.hip new file mode 100644 index 00000000000..1d1d428e5f0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c581974c8b6f43f60d0af29c350d850b55c03121.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59937be2b9a13d6520fdcc922e4e75c9fa085ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59937be2b9a13d6520fdcc922e4e75c9fa085ab.hip new file mode 100644 index 00000000000..1f709adf33d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59937be2b9a13d6520fdcc922e4e75c9fa085ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59a22c6efd8bb8815887325aa0b739e260cc754.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59a22c6efd8bb8815887325aa0b739e260cc754.hip new file mode 100644 index 00000000000..ccbb66cd88d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59a22c6efd8bb8815887325aa0b739e260cc754.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59ab718fa23f24f09a713ac28a339208a7a5802.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59ab718fa23f24f09a713ac28a339208a7a5802.hip new file mode 100644 index 00000000000..c8c411a5257 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59ab718fa23f24f09a713ac28a339208a7a5802.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5b440ca9a5196ee1e72c878c87d96934e9273c8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5b440ca9a5196ee1e72c878c87d96934e9273c8.hip new file mode 100644 index 00000000000..6e76b2218d4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5b440ca9a5196ee1e72c878c87d96934e9273c8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5fcdea177734366d3bf283317a65cc3fffda611.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5fcdea177734366d3bf283317a65cc3fffda611.hip new file mode 100644 index 00000000000..fd9d0e08b9a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5fcdea177734366d3bf283317a65cc3fffda611.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5fef330a975002ed15670e8e7b26a10376d3cb7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5fef330a975002ed15670e8e7b26a10376d3cb7.hip new file mode 100644 index 00000000000..722aae67021 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5fef330a975002ed15670e8e7b26a10376d3cb7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c64f4cdce32189065362a502105c31bd2d9d99a4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c64f4cdce32189065362a502105c31bd2d9d99a4.hip new file mode 100644 index 00000000000..00a9f7a7abb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c64f4cdce32189065362a502105c31bd2d9d99a4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c6e2da8b791d31f4ba05ef5f833fd6dea9e35f1c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c6e2da8b791d31f4ba05ef5f833fd6dea9e35f1c.hip new file mode 100644 index 00000000000..2585c19fbbc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c6e2da8b791d31f4ba05ef5f833fd6dea9e35f1c.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c7568e11e44ce70924d27e683190422cfae5c31d.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c7568e11e44ce70924d27e683190422cfae5c31d.hip new file mode 100644 index 00000000000..d34160b1b16 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c7568e11e44ce70924d27e683190422cfae5c31d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c7af2bbfac25de2853be344b9f636226c1c0112d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c7af2bbfac25de2853be344b9f636226c1c0112d.hip new file mode 100644 index 00000000000..1366275e409 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c7af2bbfac25de2853be344b9f636226c1c0112d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c806d7803d06ef8aac1d5caac9f36aafd47653d5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c806d7803d06ef8aac1d5caac9f36aafd47653d5.hip new file mode 100644 index 00000000000..13d757d9843 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c806d7803d06ef8aac1d5caac9f36aafd47653d5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c80dce1a17d073259250ec0c87ade69e639ffa8e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c80dce1a17d073259250ec0c87ade69e639ffa8e.hip new file mode 100644 index 00000000000..e50fb8502b0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c80dce1a17d073259250ec0c87ade69e639ffa8e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c8dbfaffc8a9b573f194f9c63f1175d9725f8950.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c8dbfaffc8a9b573f194f9c63f1175d9725f8950.hip new file mode 100644 index 00000000000..9ca6fc02b79 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c8dbfaffc8a9b573f194f9c63f1175d9725f8950.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c8f6461673882d636772ae4d26e78eabcb568f31.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c8f6461673882d636772ae4d26e78eabcb568f31.hip new file mode 100644 index 00000000000..e1be5aea089 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c8f6461673882d636772ae4d26e78eabcb568f31.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c919b8ed877d4244d01a17ecb948b459e361ff24.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c919b8ed877d4244d01a17ecb948b459e361ff24.hip new file mode 100644 index 00000000000..8d6e7b9ce37 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c919b8ed877d4244d01a17ecb948b459e361ff24.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c921a4790f982d48bcaf950123c699647afb739b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c921a4790f982d48bcaf950123c699647afb739b.hip new file mode 100644 index 00000000000..f22eb9438c2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c921a4790f982d48bcaf950123c699647afb739b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9312d7159369d13f3148a6f0882dfad6921ceec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9312d7159369d13f3148a6f0882dfad6921ceec.hip new file mode 100644 index 00000000000..685fdcccebf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9312d7159369d13f3148a6f0882dfad6921ceec.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9530e20038eb40c49bc8b045be0cf4e7e6b4eac.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9530e20038eb40c49bc8b045be0cf4e7e6b4eac.hip new file mode 100644 index 00000000000..e7862a10c70 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9530e20038eb40c49bc8b045be0cf4e7e6b4eac.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c977735a36c325706bd19a12df66ed0839b032b1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c977735a36c325706bd19a12df66ed0839b032b1.hip new file mode 100644 index 00000000000..fcede89ce04 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c977735a36c325706bd19a12df66ed0839b032b1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9ad71883a19b522486706d3705700c012a6fc19.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9ad71883a19b522486706d3705700c012a6fc19.hip new file mode 100644 index 00000000000..2cc6c54743d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9ad71883a19b522486706d3705700c012a6fc19.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9ba0a3369d4e4eaea1c902a90e6501f232dd57c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9ba0a3369d4e4eaea1c902a90e6501f232dd57c.hip new file mode 100644 index 00000000000..f6cee564bae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9ba0a3369d4e4eaea1c902a90e6501f232dd57c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9f1e7e478a2208c4d32e2d7e6abebdc16bcc5fe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9f1e7e478a2208c4d32e2d7e6abebdc16bcc5fe.hip new file mode 100644 index 00000000000..8d7782c8a78 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9f1e7e478a2208c4d32e2d7e6abebdc16bcc5fe.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9f28230817c9d9805c41dfcd4e834fe302e1df1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9f28230817c9d9805c41dfcd4e834fe302e1df1.hip new file mode 100644 index 00000000000..abef9cc7af6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9f28230817c9d9805c41dfcd4e834fe302e1df1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9fb8343e623e46f01893a2b61345d1ca5928671.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9fb8343e623e46f01893a2b61345d1ca5928671.hip new file mode 100644 index 00000000000..10259551d95 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9fb8343e623e46f01893a2b61345d1ca5928671.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 
= fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9fe51f982abd60e567d4238d3266fb60e45814b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9fe51f982abd60e567d4238d3266fb60e45814b.hip new file mode 100644 index 00000000000..68e98abb3da --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9fe51f982abd60e567d4238d3266fb60e45814b.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca00cfdc5592b7440d72482a18781e9cf3afb05a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca00cfdc5592b7440d72482a18781e9cf3afb05a.hip new file mode 100644 index 00000000000..04562d43e6a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca00cfdc5592b7440d72482a18781e9cf3afb05a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca1992a2634cd6674076611be54197c715ad8271.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca1992a2634cd6674076611be54197c715ad8271.hip new file mode 100644 index 00000000000..b1a8d1bdc07 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca1992a2634cd6674076611be54197c715ad8271.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca3975efd767ddf7c12e308d948bdcaf0968493a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca3975efd767ddf7c12e308d948bdcaf0968493a.hip new file mode 100644 index 00000000000..b797f8dbee0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca3975efd767ddf7c12e308d948bdcaf0968493a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca3d98ff43fbb80ceb82fc22ab039bee898969b0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca3d98ff43fbb80ceb82fc22ab039bee898969b0.hip new file mode 100644 index 00000000000..695453631ca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca3d98ff43fbb80ceb82fc22ab039bee898969b0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca4c6ad28aff1976c6dd36974ec3b339aa3090e9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca4c6ad28aff1976c6dd36974ec3b339aa3090e9.hip new file mode 100644 index 00000000000..300fb059215 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca4c6ad28aff1976c6dd36974ec3b339aa3090e9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca5681d4e5871aacef74bdba9e368445875252d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca5681d4e5871aacef74bdba9e368445875252d3.hip new file mode 100644 index 00000000000..dd82712c7c8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca5681d4e5871aacef74bdba9e368445875252d3.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca920c3239bb5796b1ab2fc75177eb3b820aa784.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca920c3239bb5796b1ab2fc75177eb3b820aa784.hip new file mode 100644 index 00000000000..8c2b34f4884 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca920c3239bb5796b1ab2fc75177eb3b820aa784.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void 
fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cabb7b12cdd9b8b522af577e13232b2459dbd38d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cabb7b12cdd9b8b522af577e13232b2459dbd38d.hip new file mode 100644 index 00000000000..b3fda98f7f2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cabb7b12cdd9b8b522af577e13232b2459dbd38d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cae6c7efbfc831e2bcfc8c1efa1a486c02627cbf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cae6c7efbfc831e2bcfc8c1efa1a486c02627cbf.hip new file mode 100644 index 00000000000..0fe3d7f3ef6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cae6c7efbfc831e2bcfc8c1efa1a486c02627cbf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_caede7a18f3e3d5e24f6c70392413a2cda16ac15.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_caede7a18f3e3d5e24f6c70392413a2cda16ac15.hip new file mode 100644 index 00000000000..55a357335bb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_caede7a18f3e3d5e24f6c70392413a2cda16ac15.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb10303a0b79f2710eb7c66896d3c1f8b12c04dd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb10303a0b79f2710eb7c66896d3c1f8b12c04dd.hip new file mode 100644 index 00000000000..5c029304800 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb10303a0b79f2710eb7c66896d3c1f8b12c04dd.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1a0ce432c27f4cfa51731c3ef181bf60c8a727.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1a0ce432c27f4cfa51731c3ef181bf60c8a727.hip new file mode 100644 index 00000000000..63387793002 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1a0ce432c27f4cfa51731c3ef181bf60c8a727.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> 
+void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1b91c16e0255fe7a0a85638b98d94634e143a9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1b91c16e0255fe7a0a85638b98d94634e143a9.hip new file mode 100644 index 00000000000..d3ed2d51816 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1b91c16e0255fe7a0a85638b98d94634e143a9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1deea4f4fab0db31d46a91228601f0c272d6e6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1deea4f4fab0db31d46a91228601f0c272d6e6.hip new file mode 100644 index 00000000000..79177c2d6fc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1deea4f4fab0db31d46a91228601f0c272d6e6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
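
Each of the generated translation units above follows the same shape: a block of type aliases pins down one concrete kernel configuration (dtype, block tile, pipeline, bias/mask/dropout flags), and the file then provides explicit specializations of the launch, one-shot, and get-name function templates keyed by a traits type, with the actual device launch guarded to gfx90a/gfx942. Below is a minimal, self-contained C++ sketch of that traits-keyed explicit-specialization pattern; every name in it (bwd_traits, run_bwd, and so on) is hypothetical and stands in for the generated ck_tile plumbing, not the real API.

#include <iostream>
#include <string>

// Compile-time key identifying one kernel variant (hypothetical names).
template <int HDim, bool kIsCausal, bool kHasDropout>
struct bwd_traits {};

struct stream_config { int log_level_ = 0; };
struct bwd_args { int batch = 1; };

// Primary templates, normally declared once in a shared header.
template <typename Traits>
float run_bwd(const stream_config& s, bwd_args a);

template <typename Traits>
std::string bwd_kernel_name();

// One generated "translation unit": explicit specializations for
// hdim=64, causal, no dropout. A real generated file would also guard
// the launch with an arch check and call the device kernel launcher.
using trait_64_causal = bwd_traits<64, true, false>;

template <>
std::string bwd_kernel_name<trait_64_causal>() { return "fmha_bwd_d64_causal"; }

template <>
float run_bwd<trait_64_causal>(const stream_config& s, bwd_args /*a*/) {
    if (s.log_level_ > 0)
        std::cout << ", " << bwd_kernel_name<trait_64_causal>() << std::flush;
    // ... create kernel arguments and grid, launch, and time the kernel ...
    return 0.0f;  // placeholder for elapsed milliseconds
}

// A dispatcher elsewhere maps runtime parameters to one specialization.
int main() {
    stream_config s;
    s.log_level_ = 1;
    std::cout << "\n" << run_bwd<trait_64_causal>(s, bwd_args{}) << " ms for "
              << bwd_kernel_name<trait_64_causal>() << "\n";
}

Splitting each specialization into its own translation unit is what lets the build compile the many kernel variants in parallel and lets a runtime dispatcher pick one by traits without instantiating every variant in a single file.
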
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb20538073888bdb3174a8e9c32d7449072aa753.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb20538073888bdb3174a8e9c32d7449072aa753.hip new file mode 100644 index 00000000000..d6d2f19f4af --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb20538073888bdb3174a8e9c32d7449072aa753.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb3d5273945c5d40cc05c2660af2df1fb7a15f3c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb3d5273945c5d40cc05c2660af2df1fb7a15f3c.hip new file mode 100644 index 00000000000..55519904a49 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb3d5273945c5d40cc05c2660af2df1fb7a15f3c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb4576e8ea5d59d7663f3760009a00a19e1b0667.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb4576e8ea5d59d7663f3760009a00a19e1b0667.hip new file mode 100644 index 00000000000..4d0c9a2b57f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb4576e8ea5d59d7663f3760009a00a19e1b0667.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbd571f4fe576fdb17d5f75a558cb6747087c7f2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbd571f4fe576fdb17d5f75a558cb6747087c7f2.hip new file mode 100644 index 00000000000..9ab334b4303 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbd571f4fe576fdb17d5f75a558cb6747087c7f2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbe5a98163e878c7697e554758ebd0597c2c1760.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbe5a98163e878c7697e554758ebd0597c2c1760.hip new file mode 100644 index 00000000000..ebfb54a3225 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbe5a98163e878c7697e554758ebd0597c2c1760.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbf3e4d4d4837a0cb33b78c4f2767b1d93da0850.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbf3e4d4d4837a0cb33b78c4f2767b1d93da0850.hip new file mode 100644 index 00000000000..84658a15c4f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbf3e4d4d4837a0cb33b78c4f2767b1d93da0850.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc127a63d56099e08125b16939dac82f0173122b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc127a63d56099e08125b16939dac82f0173122b.hip new file mode 100644 index 00000000000..d40da5d5b20 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc127a63d56099e08125b16939dac82f0173122b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc4ac5a18f57f2ebb65f7e356e858ab0d59b2133.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc4ac5a18f57f2ebb65f7e356e858ab0d59b2133.hip new file mode 100644 index 00000000000..6c4df694544 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc4ac5a18f57f2ebb65f7e356e858ab0d59b2133.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc54b107e1b557ea36b5cbaf7fe3dfce05415c86.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc54b107e1b557ea36b5cbaf7fe3dfce05415c86.hip new file mode 100644 index 00000000000..1639ece1b20 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc54b107e1b557ea36b5cbaf7fe3dfce05415c86.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> 
+void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ccac6c0e61b65c9422c7f30fbd979031698370a9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ccac6c0e61b65c9422c7f30fbd979031698370a9.hip new file mode 100644 index 00000000000..b7a2e3905d5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ccac6c0e61b65c9422c7f30fbd979031698370a9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ccd0b777df1328bf24e070ed4cdf8615bb2199fe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ccd0b777df1328bf24e070ed4cdf8615bb2199fe.hip new file mode 100644 index 00000000000..39bdcc20a8b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ccd0b777df1328bf24e070ed4cdf8615bb2199fe.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd0453a5c3828c1358360f31f5d3b7258e17fdb9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd0453a5c3828c1358360f31f5d3b7258e17fdb9.hip new file mode 100644 index 00000000000..51260dc0ce9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd0453a5c3828c1358360f31f5d3b7258e17fdb9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd4efcdd12184211c74e7b3f2f30fecf1041ca32.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd4efcdd12184211c74e7b3f2f30fecf1041ca32.hip new file mode 100644 index 00000000000..78137a331f0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd4efcdd12184211c74e7b3f2f30fecf1041ca32.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd757a8bbeabd16a44d149ab188430f6d79ddcaf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd757a8bbeabd16a44d149ab188430f6d79ddcaf.hip new file mode 100644 index 00000000000..eb744ff3279 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd757a8bbeabd16a44d149ab188430f6d79ddcaf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cde0582e1aef74f9209de638b553ec0671476258.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cde0582e1aef74f9209de638b553ec0671476258.hip new file mode 100644 index 00000000000..2307064b178 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cde0582e1aef74f9209de638b553ec0671476258.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce4714e4f33340859c106a3129993e22652262e2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce4714e4f33340859c106a3129993e22652262e2.hip new file mode 100644 index 00000000000..32a84284387 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce4714e4f33340859c106a3129993e22652262e2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5064e27ba427cb951f7e1b01328b0beb6b2b7c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5064e27ba427cb951f7e1b01328b0beb6b2b7c.hip new file mode 100644 index 00000000000..9ce46029f56 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5064e27ba427cb951f7e1b01328b0beb6b2b7c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5ad502dd40353312d561e9f40aa478c16ef5b1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5ad502dd40353312d561e9f40aa478c16ef5b1.hip new file mode 100644 index 00000000000..77d65db91d8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5ad502dd40353312d561e9f40aa478c16ef5b1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5b5932f6df9a194ceb0d69220fba9596528eec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5b5932f6df9a194ceb0d69220fba9596528eec.hip new file mode 100644 index 00000000000..0411dfdd8e7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5b5932f6df9a194ceb0d69220fba9596528eec.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5c161b725becf059fb4439c668edd454ac77d1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5c161b725becf059fb4439c668edd454ac77d1.hip new file mode 100644 index 00000000000..051745befa8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5c161b725becf059fb4439c668edd454ac77d1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce909cb5f96a4884caa0d2eb8c5e6bc7fa352797.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce909cb5f96a4884caa0d2eb8c5e6bc7fa352797.hip new file mode 100644 index 00000000000..95175844877 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce909cb5f96a4884caa0d2eb8c5e6bc7fa352797.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ceb9544e2a0caae2c9e3dd8bbd2c509e8dca1379.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ceb9544e2a0caae2c9e3dd8bbd2c509e8dca1379.hip new file mode 100644 index 00000000000..e9b81569dce --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ceb9544e2a0caae2c9e3dd8bbd2c509e8dca1379.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cee81ab2e2678816c7b516d2d4c50e8cb5874c68.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cee81ab2e2678816c7b516d2d4c50e8cb5874c68.hip new file mode 100644 index 00000000000..ac3fce28609 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cee81ab2e2678816c7b516d2d4c50e8cb5874c68.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
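Note (illustrative, not part of the generated patch): each autogenerated translation unit above defines one kernel instantiation plus explicit specializations of fmha_bwd_dq_dk_dv_, fmha_bwd_dq_dk_dv_oneshot_ and fmha_bwd_dq_dk_dv_get_name_, keyed by the dq_dk_dv_trait_0 type it declares. The sketch below shows how a caller could select such an instance; it assumes fmha_bwd_dq_dk_dv_ is a function template dispatched on that trait type, and run_bwd_instance is a hypothetical wrapper name, not ATen code.

// Sketch only: invoking the specialization generated above.
// dq_dk_dv_trait_0, fmha_bwd_args and ck_tile::stream_config come from the
// generated file; run_bwd_instance is a hypothetical helper.
static float run_bwd_instance(const ck_tile::stream_config& s, fmha_bwd_args a)
{
    // Naming the trait selects the matching explicit specialization, i.e. the
    // FmhaBwdDQDKDVKernel instantiation compiled in that translation unit,
    // and returns the elapsed time reported by ck_tile::launch_kernel.
    return fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(s, a);
}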
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cf5c6c0bfaf98f6e655fc443246b81fcc730fe97.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cf5c6c0bfaf98f6e655fc443246b81fcc730fe97.hip new file mode 100644 index 00000000000..940795ee7ce --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cf5c6c0bfaf98f6e655fc443246b81fcc730fe97.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cf73e1fc0015094861ca0c1c81bacdbe0c5b8f37.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cf73e1fc0015094861ca0c1c81bacdbe0c5b8f37.hip new file mode 100644 index 00000000000..237f75bd36c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cf73e1fc0015094861ca0c1c81bacdbe0c5b8f37.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cfda56a4eb08b803332f25bda6209932d9624acc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cfda56a4eb08b803332f25bda6209932d9624acc.hip new file mode 100644 index 00000000000..c9500a1d03d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cfda56a4eb08b803332f25bda6209932d9624acc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cfec97bdfb6fa95e057eaf5a8138853e1c0884f2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cfec97bdfb6fa95e057eaf5a8138853e1c0884f2.hip new file mode 100644 index 00000000000..3a213b6a76d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cfec97bdfb6fa95e057eaf5a8138853e1c0884f2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d00f65bc99ca08eba66564d34f72f2769bff9491.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d00f65bc99ca08eba66564d34f72f2769bff9491.hip new file mode 100644 index 00000000000..526625594bc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d00f65bc99ca08eba66564d34f72f2769bff9491.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d036096f49a89730f8af7e75457c88cb8ae64165.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d036096f49a89730f8af7e75457c88cb8ae64165.hip new file mode 100644 index 00000000000..155d14f6df7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d036096f49a89730f8af7e75457c88cb8ae64165.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
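For reference (an observation on the shared pattern, not new patch content): every generated launcher above guards the actual launch behind the gfx90a/gfx942 architecture check and otherwise compiles to a stub that reports zero elapsed time. A hedged, generic restatement of that pattern follows; launch_if_supported is a hypothetical name, and the launch_kernel/make_kernel calls are written as they appear in the generated launchers.

// Sketch of the arch-guard pattern shared by the generated launchers above.
// launch_if_supported is a hypothetical helper; Kernel and Kargs stand in for
// the concrete kernel and argument types each generated file supplies.
template <typename Kernel, typename Kargs>
float launch_if_supported(const ck_tile::stream_config& s,
                          Kernel k, dim3 grids, dim3 blocks, const Kargs& kargs)
{
#if (defined(__gfx90a__) || defined(__gfx942__))
    // Supported target: build and time the kernel launch.
    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel(k, grids, blocks, 0, kargs));
#else
    // Unsupported target: the generated stub reports zero elapsed time.
    return 0.0;
#endif
}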
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d049a1b8f4c1c6d37973ce38593efda1de8ce0cd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d049a1b8f4c1c6d37973ce38593efda1de8ce0cd.hip new file mode 100644 index 00000000000..3f90311218b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d049a1b8f4c1c6d37973ce38593efda1de8ce0cd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d04dc4ed02eb42c3fe303342801ed3073a0dcb8e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d04dc4ed02eb42c3fe303342801ed3073a0dcb8e.hip new file mode 100644 index 00000000000..8c9be1eefe1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d04dc4ed02eb42c3fe303342801ed3073a0dcb8e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d06ba4c996570ddab77b6ff1e2a0101b638543eb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d06ba4c996570ddab77b6ff1e2a0101b638543eb.hip new file mode 100644 index 00000000000..e38d6adf2ba --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d06ba4c996570ddab77b6ff1e2a0101b638543eb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0863830fc5d43dc6d6400280e892bb7de2892d4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0863830fc5d43dc6d6400280e892bb7de2892d4.hip new file mode 100644 index 00000000000..21d521abdcd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0863830fc5d43dc6d6400280e892bb7de2892d4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d090b771a4f9750132f549c82a88b4ab00dce5c7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d090b771a4f9750132f549c82a88b4ab00dce5c7.hip new file mode 100644 index 00000000000..abb7de42e88 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d090b771a4f9750132f549c82a88b4ab00dce5c7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0b09e8513646fbb2a007544a63ec9e2b04dc4c2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0b09e8513646fbb2a007544a63ec9e2b04dc4c2.hip new file mode 100644 index 00000000000..80b87189a9b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0b09e8513646fbb2a007544a63ec9e2b04dc4c2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0daa59f5dce6fc3965193ae37d8c82a3d1834e6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0daa59f5dce6fc3965193ae37d8c82a3d1834e6.hip new file mode 100644 index 00000000000..e973fe505b7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0daa59f5dce6fc3965193ae37d8c82a3d1834e6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0dd0165ee91c095a19ceddf08789e3576912590.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0dd0165ee91c095a19ceddf08789e3576912590.hip new file mode 100644 index 00000000000..69856a46bb5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0dd0165ee91c095a19ceddf08789e3576912590.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0de618ff3ea9f67b90f2227fb7fcc74ea34183d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0de618ff3ea9f67b90f2227fb7fcc74ea34183d.hip new file mode 100644 index 00000000000..72ec5b33008 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0de618ff3ea9f67b90f2227fb7fcc74ea34183d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0f63cafbeb445408c884727b473667fb479675e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0f63cafbeb445408c884727b473667fb479675e.hip new file mode 100644 index 00000000000..19076243442 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0f63cafbeb445408c884727b473667fb479675e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d137b7b6e04e1caf43a62bd6788a75361cfa98f6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d137b7b6e04e1caf43a62bd6788a75361cfa98f6.hip new file mode 100644 index 00000000000..de13791a91f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d137b7b6e04e1caf43a62bd6788a75361cfa98f6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1840494c4fa78ff399c0399b3ad7ca3d22d4587.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1840494c4fa78ff399c0399b3ad7ca3d22d4587.hip new file mode 100644 index 00000000000..0b9b5772cd6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1840494c4fa78ff399c0399b3ad7ca3d22d4587.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d18727988e47264b42b4153dc82fc1a750f08db0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d18727988e47264b42b4153dc82fc1a750f08db0.hip new file mode 100644 index 00000000000..4d02471bf8f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d18727988e47264b42b4153dc82fc1a750f08db0.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1c0dfd19a08d61586758091370acbdc6f267017.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1c0dfd19a08d61586758091370acbdc6f267017.hip new file mode 100644 index 00000000000..f3dc0e43f13 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1c0dfd19a08d61586758091370acbdc6f267017.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1c25cfc437d8bd803860e39a45b2f3b9fa48393.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1c25cfc437d8bd803860e39a45b2f3b9fa48393.hip new file mode 100644 index 00000000000..5e557689795 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1c25cfc437d8bd803860e39a45b2f3b9fa48393.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1d3eacc320104100bce46235fe656e5a8223c66.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1d3eacc320104100bce46235fe656e5a8223c66.hip new file mode 100644 index 00000000000..779433cfd7e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1d3eacc320104100bce46235fe656e5a8223c66.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d20d45aa85c0daa299da98c277cee826fe67bd27.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d20d45aa85c0daa299da98c277cee826fe67bd27.hip new file mode 100644 index 00000000000..8f05e1243fd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d20d45aa85c0daa299da98c277cee826fe67bd27.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d257148f457557ea80ca56690e525db3a4b0ff55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d257148f457557ea80ca56690e525db3a4b0ff55.hip new file mode 100644 index 00000000000..5dba5da8184 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d257148f457557ea80ca56690e525db3a4b0ff55.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d25ce4b3e9cc392ceafebc7fe3bcbe05aaad4bbc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d25ce4b3e9cc392ceafebc7fe3bcbe05aaad4bbc.hip new file mode 100644 index 00000000000..28fa3b5d1cf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d25ce4b3e9cc392ceafebc7fe3bcbe05aaad4bbc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2d08c5470a385d0160b2c1441fd1c30fff1c17c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2d08c5470a385d0160b2c1441fd1c30fff1c17c.hip new file mode 100644 index 00000000000..a7555228dfa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2d08c5470a385d0160b2c1441fd1c30fff1c17c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2daccc4b3a0f90bff39cb4597f8b7e484613d9e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2daccc4b3a0f90bff39cb4597f8b7e484613d9e.hip new file mode 100644 index 00000000000..45a37392130 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2daccc4b3a0f90bff39cb4597f8b7e484613d9e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2dfdb42c1b380e860aa5609302f29698dd27923.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2dfdb42c1b380e860aa5609302f29698dd27923.hip new file mode 100644 index 00000000000..ac951973b6c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2dfdb42c1b380e860aa5609302f29698dd27923.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2f4b869ff23874b6bde0aab68c419108b7e69f4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2f4b869ff23874b6bde0aab68c419108b7e69f4.hip new file mode 100644 index 00000000000..f8a74991530 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2f4b869ff23874b6bde0aab68c419108b7e69f4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d32c64ef01aa228277d031a74df51363f98aa2b0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d32c64ef01aa228277d031a74df51363f98aa2b0.hip new file mode 100644 index 00000000000..13d7502c3a1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d32c64ef01aa228277d031a74df51363f98aa2b0.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d34d6cdcd81a456125ab5e0875466c6334d8e5c8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d34d6cdcd81a456125ab5e0875466c6334d8e5c8.hip new file mode 100644 index 00000000000..9654803a875 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d34d6cdcd81a456125ab5e0875466c6334d8e5c8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d34fcb56caa8f80404789fba0ffac447483a4d84.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d34fcb56caa8f80404789fba0ffac447483a4d84.hip new file mode 100644 index 00000000000..9d12efc0749 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d34fcb56caa8f80404789fba0ffac447483a4d84.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3784fb4c0685d7b651f4113f3c71e050881f3a5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3784fb4c0685d7b651f4113f3c71e050881f3a5.hip new file mode 100644 index 00000000000..de697b2f2c1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3784fb4c0685d7b651f4113f3c71e050881f3a5.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3a23ded424200d0c6f06b1dbd0a7b7b0e7b5d9b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3a23ded424200d0c6f06b1dbd0a7b7b0e7b5d9b.hip new file mode 100644 index 00000000000..2cc4d9d90a2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3a23ded424200d0c6f06b1dbd0a7b7b0e7b5d9b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3a2edf232786d458e2125f8dfeda8847f842afa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3a2edf232786d458e2125f8dfeda8847f842afa.hip new file mode 100644 index 00000000000..27aca3a502c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3a2edf232786d458e2125f8dfeda8847f842afa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3af8763f289dace1054bdcb4dfeda28b0aefcae.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3af8763f289dace1054bdcb4dfeda28b0aefcae.hip new file mode 100644 index 00000000000..b6b58223819 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3af8763f289dace1054bdcb4dfeda28b0aefcae.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + true, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3fce1e11aee2273620e75efe4aa0390fcde9ba5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3fce1e11aee2273620e75efe4aa0390fcde9ba5.hip new file mode 100644 index 00000000000..c2d920704a6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3fce1e11aee2273620e75efe4aa0390fcde9ba5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d40569ae9dbd693c0ab3d6ba69704d31e451011b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d40569ae9dbd693c0ab3d6ba69704d31e451011b.hip new file mode 100644 index 00000000000..e3e7d8ca723 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d40569ae9dbd693c0ab3d6ba69704d31e451011b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d41b6a64dd181f2efa65aaed03a3d229b3566c1d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d41b6a64dd181f2efa65aaed03a3d229b3566c1d.hip new file mode 100644 index 00000000000..a6e79f9bf0c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d41b6a64dd181f2efa65aaed03a3d229b3566c1d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d41cd6b60a97e7071518cbd1a63abb8b910df024.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d41cd6b60a97e7071518cbd1a63abb8b910df024.hip new file mode 100644 index 00000000000..b1609b190fa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d41cd6b60a97e7071518cbd1a63abb8b910df024.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d43715cce8935439f90172d141050d78c7e76fb7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d43715cce8935439f90172d141050d78c7e76fb7.hip new file mode 100644 index 00000000000..8173e29193c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d43715cce8935439f90172d141050d78c7e76fb7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4605b2ad3e3753c5f255678abc1690b949c5abc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4605b2ad3e3753c5f255678abc1690b949c5abc.hip new file mode 100644 index 00000000000..6cadf5fbdd6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4605b2ad3e3753c5f255678abc1690b949c5abc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4645b713821371161a9925dec8a3d6c157ba1aa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4645b713821371161a9925dec8a3d6c157ba1aa.hip new file mode 100644 index 00000000000..dd32abcd1e4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4645b713821371161a9925dec8a3d6c157ba1aa.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4aff499ad527be5fe33b8e92547df57af26d40d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4aff499ad527be5fe33b8e92547df57af26d40d.hip new file mode 100644 index 00000000000..3c007c9e372 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4aff499ad527be5fe33b8e92547df57af26d40d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t 
kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4b99af9a573df50a27fccbec3fa8e350f1854eb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4b99af9a573df50a27fccbec3fa8e350f1854eb.hip new file mode 100644 index 00000000000..0244535a25a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4b99af9a573df50a27fccbec3fa8e350f1854eb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + 
true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4c9f975891087e6eed6393629b41155deafc509.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4c9f975891087e6eed6393629b41155deafc509.hip new file mode 100644 index 00000000000..b755f2652f6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4c9f975891087e6eed6393629b41155deafc509.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
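[Editor's note, not part of the patch] Each of these autogenerated translation units supplies one explicit specialization of the `fmha_bwd_dq_dk_dv_` / `fmha_bwd_dq_dk_dv_oneshot_` / `fmha_bwd_dq_dk_dv_get_name_` family, keyed on a trait alias (`dq_dk_dv_trait_0`) that encodes head dimension, data type, pipeline, mask, bias, dropout and padding flags. As an illustration only, and independent of the actual ck_tile API, the following minimal sketch shows the general trait-keyed explicit-specialization pattern these files follow; every name in it (`example_traits_`, `example_run_`, `stream_config_sketch`) is hypothetical.

// Illustrative sketch only: generic trait-keyed dispatch, no ck_tile dependency.
// All identifiers below are hypothetical and not part of this patch.
#include <iostream>

struct stream_config_sketch { int log_level_ = 0; };

// A "trait" type carries the compile-time configuration (head dim, causal flag, ...).
template <int HeadDim, bool kIsCausal>
struct example_traits_ {};

// Primary template is only declared; each generated .hip file would define one
// explicit specialization for one concrete configuration.
template <typename Traits>
float example_run_(const stream_config_sketch& s);

template <>
float example_run_<example_traits_<128, false>>(const stream_config_sketch& s)
{
    if (s.log_level_ > 0)
        std::cout << "example_kernel_hd128_noncausal" << std::flush;
    // A real instance would build kernel arguments and launch the device kernel here.
    return 0.0f;  // stand-in for the reported kernel time
}

int main()
{
    stream_config_sketch s{1};
    // Host-side dispatch selects the specialization matching the requested configuration.
    return static_cast<int>(example_run_<example_traits_<128, false>>(s));
}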
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d50ac8e8a03f8e7ec2c6e993dd39f09f465dab57.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d50ac8e8a03f8e7ec2c6e993dd39f09f465dab57.hip new file mode 100644 index 00000000000..bfe25191b32 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d50ac8e8a03f8e7ec2c6e993dd39f09f465dab57.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d54ac01458df3f240e0656d82330f9de23ba9651.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d54ac01458df3f240e0656d82330f9de23ba9651.hip new file mode 100644 index 00000000000..1580ca2a5d2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d54ac01458df3f240e0656d82330f9de23ba9651.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d54b3731883a5f8393d60d27487f8d017aedd3f9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d54b3731883a5f8393d60d27487f8d017aedd3f9.hip new file mode 100644 index 00000000000..5e6c8d14342 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d54b3731883a5f8393d60d27487f8d017aedd3f9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d5e82799f4452e148c3e02acd6526cf30757eb52.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d5e82799f4452e148c3e02acd6526cf30757eb52.hip new file mode 100644 index 00000000000..6e06047e93a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d5e82799f4452e148c3e02acd6526cf30757eb52.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d5edfe3e3dc3008b928c8e6dbd50784b905f189e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d5edfe3e3dc3008b928c8e6dbd50784b905f189e.hip new file mode 100644 index 00000000000..44a58feba4d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d5edfe3e3dc3008b928c8e6dbd50784b905f189e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d600779c17b7b21c18e1308e6d765fe02a7945d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d600779c17b7b21c18e1308e6d765fe02a7945d3.hip new file mode 100644 index 00000000000..06104fd7b83 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d600779c17b7b21c18e1308e6d765fe02a7945d3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d6149eea92f2c40c11de3b778102fcf9b6a006b8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d6149eea92f2c40c11de3b778102fcf9b6a006b8.hip new file mode 100644 index 00000000000..67d05f1be2d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d6149eea92f2c40c11de3b778102fcf9b6a006b8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS 
AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, 
blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d623b36cc3f56d1001b2d3abadd8a5628fefd014.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d623b36cc3f56d1001b2d3abadd8a5628fefd014.hip new file mode 100644 index 00000000000..4aac45ef838 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d623b36cc3f56d1001b2d3abadd8a5628fefd014.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d63c8c746055851217a514321cd735eaf6937263.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d63c8c746055851217a514321cd735eaf6937263.hip new file mode 100644 index 00000000000..364b252ab1a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d63c8c746055851217a514321cd735eaf6937263.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = 
k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d64b8b52f4a98801e185e2f132b2f80c29dd0c37.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d64b8b52f4a98801e185e2f132b2f80c29dd0c37.hip new file mode 100644 index 00000000000..7f9835a1a67 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d64b8b52f4a98801e185e2f132b2f80c29dd0c37.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d66b79c4ebdcfd239cecec58203606bc123bd6bb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d66b79c4ebdcfd239cecec58203606bc123bd6bb.hip new file mode 100644 index 00000000000..309f85c7f83 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d66b79c4ebdcfd239cecec58203606bc123bd6bb.hip @@ -0,0 +1,84 @@ +// 
========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d66c30148a6fa816937f2f095802264d3dfa0273.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d66c30148a6fa816937f2f095802264d3dfa0273.hip new file mode 100644 index 00000000000..681e2734383 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d66c30148a6fa816937f2f095802264d3dfa0273.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d703eea8075cacec4d41fee7dc4734f593ee79e8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d703eea8075cacec4d41fee7dc4734f593ee79e8.hip new file mode 100644 index 00000000000..911884f18d6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d703eea8075cacec4d41fee7dc4734f593ee79e8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d712f23ef88ae5d7b161d36f42d22a5ba53b6354.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d712f23ef88ae5d7b161d36f42d22a5ba53b6354.hip new file mode 100644 index 00000000000..735e3461e03 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d712f23ef88ae5d7b161d36f42d22a5ba53b6354.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d713fe25dc90b3511fc259cebf463376dcb55d84.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d713fe25dc90b3511fc259cebf463376dcb55d84.hip new file mode 100644 index 00000000000..f31a6366ab5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d713fe25dc90b3511fc259cebf463376dcb55d84.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
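[Editor's note, not part of the patch] Every generated launcher wraps the actual launch in the target-architecture check `#if (defined(__gfx90a__) || defined(__gfx942__))`, so translation units compiled for other gfx targets still link but degenerate to a stub that returns 0.0. A minimal, ck_tile-free sketch of that guard pattern follows; `example_launch_guarded` is hypothetical and only illustrates the guard, not the real launch code.

// Illustrative sketch of the per-architecture guard used in the generated launchers.
// example_launch_guarded is hypothetical; it is not part of this patch.
float example_launch_guarded()
{
#if (defined(__gfx90a__) || defined(__gfx942__))
    // Device passes for gfx90a / gfx942 would launch the kernel and return the
    // measured time here.
    return 1.0f;  // stand-in for a measured kernel time
#else
    // Any other target compiles to a no-op stub so the fat binary still links.
    return 0.0f;
#endif
}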
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7145383e39dec0e346b5094401acf85ef3c2075.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7145383e39dec0e346b5094401acf85ef3c2075.hip new file mode 100644 index 00000000000..da226878f4a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7145383e39dec0e346b5094401acf85ef3c2075.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d723b191785c97d284675f700a7baeb52a2eb791.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d723b191785c97d284675f700a7baeb52a2eb791.hip new file mode 100644 index 00000000000..2b519281fc3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d723b191785c97d284675f700a7baeb52a2eb791.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7290cc4c3036c9205e689cbcc60e7d16b97a7d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7290cc4c3036c9205e689cbcc60e7d16b97a7d6.hip new file mode 100644 index 00000000000..6e496a908dc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7290cc4c3036c9205e689cbcc60e7d16b97a7d6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d733f4c03e338ea7c6d8f759c1132499bdcea059.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d733f4c03e338ea7c6d8f759c1132499bdcea059.hip new file mode 100644 index 00000000000..5bf0ce3079a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d733f4c03e338ea7c6d8f759c1132499bdcea059.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d773df9ccfc1ace90fe3afb5c00976deabedf6f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d773df9ccfc1ace90fe3afb5c00976deabedf6f8.hip new file mode 100644 index 00000000000..6f5ddf20cbf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d773df9ccfc1ace90fe3afb5c00976deabedf6f8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7adde8780b39f1364c572a19c3bfb19417678e3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7adde8780b39f1364c572a19c3bfb19417678e3.hip new file mode 100644 index 00000000000..f0e55dd85db --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7adde8780b39f1364c572a19c3bfb19417678e3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7bda8157fb27d544e049fd7d2ec735725f1bf44.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7bda8157fb27d544e049fd7d2ec735725f1bf44.hip new file mode 100644 index 00000000000..01bac69a55b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7bda8157fb27d544e049fd7d2ec735725f1bf44.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7fae2c18645d36a181a0bdd2d8ca7a4ac0f6d1d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7fae2c18645d36a181a0bdd2d8ca7a4ac0f6d1d.hip new file mode 100644 index 00000000000..8fcd9523af0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7fae2c18645d36a181a0bdd2d8ca7a4ac0f6d1d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d82773721479613ad72e334510a248f1436b38d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d82773721479613ad72e334510a248f1436b38d6.hip new file mode 100644 index 00000000000..de30f89448d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d82773721479613ad72e334510a248f1436b38d6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d867098db97b3f26e71a151c63b74260bfab21f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d867098db97b3f26e71a151c63b74260bfab21f8.hip new file mode 100644 index 00000000000..6e8215d41b5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d867098db97b3f26e71a151c63b74260bfab21f8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d86e4dcbe9c4cac8f7c8c5d97ce384ae0cbdbfbc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d86e4dcbe9c4cac8f7c8c5d97ce384ae0cbdbfbc.hip new file mode 100644 index 00000000000..854b06d1b4b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d86e4dcbe9c4cac8f7c8c5d97ce384ae0cbdbfbc.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d8901a63986cc28ef24cab012b32114851a8c1ec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d8901a63986cc28ef24cab012b32114851a8c1ec.hip new file mode 100644 index 00000000000..468d51dcd22 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d8901a63986cc28ef24cab012b32114851a8c1ec.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9061c204d8a85c974676f4438994a0be9d69a60.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9061c204d8a85c974676f4438994a0be9d69a60.hip new file mode 100644 index 00000000000..50210fe32ef --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9061c204d8a85c974676f4438994a0be9d69a60.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d924ee32b178b6bffa7a71603d6e2818f66177a5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d924ee32b178b6bffa7a71603d6e2818f66177a5.hip new file mode 100644 index 00000000000..77cb2deae9a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d924ee32b178b6bffa7a71603d6e2818f66177a5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d937609afa8e21a761dad6b01ff3f26346e450fc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d937609afa8e21a761dad6b01ff3f26346e450fc.hip new file mode 100644 index 00000000000..9c64d867fe5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d937609afa8e21a761dad6b01ff3f26346e450fc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d95835bc6f000d3a3379bbc38d90e83dcaf867ee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d95835bc6f000d3a3379bbc38d90e83dcaf867ee.hip new file mode 100644 index 00000000000..f2caaa8f0d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d95835bc6f000d3a3379bbc38d90e83dcaf867ee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d992eab7de49033f5480c5e86a69e675db0d2a19.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d992eab7de49033f5480c5e86a69e675db0d2a19.hip new file mode 100644 index 00000000000..84c378952ce --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d992eab7de49033f5480c5e86a69e675db0d2a19.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9c23b7f8fcc4e4f4c81f5f00cfd345b98df2e0f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9c23b7f8fcc4e4f4c81f5f00cfd345b98df2e0f.hip new file mode 100644 index 00000000000..1c2886477b7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9c23b7f8fcc4e4f4c81f5f00cfd345b98df2e0f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9c3e27b522320dcca5ee84fa534b03aae2bfea9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9c3e27b522320dcca5ee84fa534b03aae2bfea9.hip new file mode 100644 index 00000000000..d93a3dab005 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9c3e27b522320dcca5ee84fa534b03aae2bfea9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da07d8b5666423da30a95e3b2cabd3839d200981.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da07d8b5666423da30a95e3b2cabd3839d200981.hip new file mode 100644 index 00000000000..ec6495c8bb0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da07d8b5666423da30a95e3b2cabd3839d200981.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
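// (Editorial note, not emitted by generate.py) The kernel alias above bundles the selected
// backward pipeline with the dK and dV epilogues into one FmhaBwdDQDKDVKernel
// instantiation; the dq_dk_dv_trait_0 alias that follows is the compile-time key for
// which this file provides its three explicit specializations near the end
// (timed launch, one-shot launch, and kernel-name lookup).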
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da29a515d14dac02066bcd4701285b9916b43cf5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da29a515d14dac02066bcd4701285b9916b43cf5.hip new file mode 100644 index 00000000000..badef10c656 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da29a515d14dac02066bcd4701285b9916b43cf5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da6afccdee4107507a64323e17bf12c46da2b92a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da6afccdee4107507a64323e17bf12c46da2b92a.hip new file mode 100644 index 00000000000..504de596f41 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da6afccdee4107507a64323e17bf12c46da2b92a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da74887afedbd67928fe4d596709f9ff92530611.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da74887afedbd67928fe4d596709f9ff92530611.hip new file mode 100644 index 00000000000..c1447654d3b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da74887afedbd67928fe4d596709f9ff92530611.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da822ea727fb3543e445e4000f7e6ebb946d6a3b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da822ea727fb3543e445e4000f7e6ebb946d6a3b.hip new file mode 100644 index 00000000000..4b6115fcd04 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da822ea727fb3543e445e4000f7e6ebb946d6a3b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da9f6e1d59132fe96709490af25bd794f267851c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da9f6e1d59132fe96709490af25bd794f267851c.hip new file mode 100644 index 00000000000..fdb8da65bbe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da9f6e1d59132fe96709490af25bd794f267851c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db0d0cf55d90b3f3c9eecada1db93c420f34b1ae.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db0d0cf55d90b3f3c9eecada1db93c420f34b1ae.hip new file mode 100644 index 00000000000..de7d4a18f0e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db0d0cf55d90b3f3c9eecada1db93c420f34b1ae.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db5016bff9e5dc37184d2b9417eb351c7ea1c322.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db5016bff9e5dc37184d2b9417eb351c7ea1c322.hip new file mode 100644 index 00000000000..d2fa64da9e9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db5016bff9e5dc37184d2b9417eb351c7ea1c322.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} 
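Each autogenerated translation unit in this directory follows the same shape: it pins one kernel configuration (head dimension, data type, pipeline, bias and padding flags) in a traits alias, then defines the explicit specializations keyed on that alias (fmha_fwd_ for the forward files; fmha_bwd_dq_dk_dv_, its oneshot variant, and the get_name helper for the backward files) that the fmha dispatch code resolves against. The sketch below only illustrates that dispatch-by-explicit-specialization technique in isolation; kernel_traits, run_kernel and kernel_name are invented names for this example and are not part of the ck_tile API.

#include <iostream>
#include <string>

// A traits tag fixes one kernel configuration at compile time, mirroring the
// per-file fmha_*_traits_ aliases in the generated sources above.
template <int HeadDim, bool IsCausal>
struct kernel_traits {};

// The primary templates are only declared; every supported configuration must
// supply explicit specializations (normally one pair per generated .hip file).
template <typename Traits>
float run_kernel();

template <typename Traits>
std::string kernel_name();

// One "generated file" worth of specializations for head dim 64, causal masking.
template <>
float run_kernel<kernel_traits<64, true>>()
{
    // A real generated file would build kargs/grids and launch the device kernel here.
    return 0.0f;
}

template <>
std::string kernel_name<kernel_traits<64, true>>()
{
    return "fmha_fwd_hd64_causal";
}

int main()
{
    std::cout << kernel_name<kernel_traits<64, true>>() << " -> "
              << run_kernel<kernel_traits<64, true>>() << '\n';
    return 0;
}

Keeping one specialization per file keeps every .hip translation unit small enough to compile independently, which is presumably why generate.py emits many small files rather than one monolithic dispatcher.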
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db85839ee8d464c5a81b8dad9839f5e0f4b467a8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db85839ee8d464c5a81b8dad9839f5e0f4b467a8.hip new file mode 100644 index 00000000000..2d86919fe42 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db85839ee8d464c5a81b8dad9839f5e0f4b467a8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db8f0bd93b352d28c5b6d78f4332026993f0bea4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db8f0bd93b352d28c5b6d78f4332026993f0bea4.hip new file mode 100644 index 00000000000..3118085b950 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db8f0bd93b352d28c5b6d78f4332026993f0bea4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbae1670fac6812b2d2cbad973e4b475509ea504.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbae1670fac6812b2d2cbad973e4b475509ea504.hip new file mode 100644 index 00000000000..97c43eac251 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbae1670fac6812b2d2cbad973e4b475509ea504.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbb06b43d5d65429e23cc717448cf1fffb0cfd74.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbb06b43d5d65429e23cc717448cf1fffb0cfd74.hip new file mode 100644 index 00000000000..e76974e6def --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbb06b43d5d65429e23cc717448cf1fffb0cfd74.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbc4135fce01e8731fec7a78d0cc0fdeeae28b90.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbc4135fce01e8731fec7a78d0cc0fdeeae28b90.hip new file mode 100644 index 00000000000..63d99059cec --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbc4135fce01e8731fec7a78d0cc0fdeeae28b90.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbcea8f7b5930abf76eecefce92d0db785d2df5d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbcea8f7b5930abf76eecefce92d0db785d2df5d.hip new file mode 100644 index 00000000000..7f0922b1a0f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbcea8f7b5930abf76eecefce92d0db785d2df5d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbde2ef18e2174ebe13a6e7c8c2a6b05a6612047.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbde2ef18e2174ebe13a6e7c8c2a6b05a6612047.hip new file mode 100644 index 00000000000..606feadb0e7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbde2ef18e2174ebe13a6e7c8c2a6b05a6612047.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc039d422a57c159ea4dbcc867d766ff1b356a07.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc039d422a57c159ea4dbcc867d766ff1b356a07.hip new file mode 100644 index 00000000000..64d84451ed8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc039d422a57c159ea4dbcc867d766ff1b356a07.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc08afbff5def8bcb4e823657ce01f57c9dc77c9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc08afbff5def8bcb4e823657ce01f57c9dc77c9.hip new file mode 100644 index 00000000000..35bc73101c2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc08afbff5def8bcb4e823657ce01f57c9dc77c9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc184767d723f4995791848cdc68bd948408204f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc184767d723f4995791848cdc68bd948408204f.hip new file mode 100644 index 00000000000..dad0aefdc8e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc184767d723f4995791848cdc68bd948408204f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc1a7f9b1afeba6690fdc0d0d1755ea89c805573.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc1a7f9b1afeba6690fdc0d0d1755ea89c805573.hip new file mode 100644 index 00000000000..7ef4bb17825 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc1a7f9b1afeba6690fdc0d0d1755ea89c805573.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc34b6ef496d4e0d8fbbe10731d4a7b1c136c036.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc34b6ef496d4e0d8fbbe10731d4a7b1c136c036.hip new file mode 100644 index 00000000000..06746262630 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc34b6ef496d4e0d8fbbe10731d4a7b1c136c036.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc3d625c5ad3e871f5a727ac946df642d988b9ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc3d625c5ad3e871f5a727ac946df642d988b9ab.hip new file mode 100644 index 00000000000..4330c6ede23 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc3d625c5ad3e871f5a727ac946df642d988b9ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc4d27535b9570b8f4b790470a83c1d0a9a2b6ce.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc4d27535b9570b8f4b790470a83c1d0a9a2b6ce.hip new file mode 100644 index 00000000000..a8387a6b859 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc4d27535b9570b8f4b790470a83c1d0a9a2b6ce.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc5ba6d73f331c76e696953606c5b347b6a46f3f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc5ba6d73f331c76e696953606c5b347b6a46f3f.hip new file mode 100644 index 00000000000..f023946e8d6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc5ba6d73f331c76e696953606c5b347b6a46f3f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc62a8db637d32e7dfdb2521cbdae6e1fbbd5fd1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc62a8db637d32e7dfdb2521cbdae6e1fbbd5fd1.hip new file mode 100644 index 00000000000..0f2fd246d49 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc62a8db637d32e7dfdb2521cbdae6e1fbbd5fd1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc818f3ce244743cb1dbff9aca399df90742a6d0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc818f3ce244743cb1dbff9aca399df90742a6d0.hip new file mode 100644 index 00000000000..12eb566a4a9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc818f3ce244743cb1dbff9aca399df90742a6d0.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc91797c1474a368e9cb056b50b4629d7736c3cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc91797c1474a368e9cb056b50b4629d7736c3cb.hip new file mode 100644 index 00000000000..8581e0ed61b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc91797c1474a368e9cb056b50b4629d7736c3cb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc9e54273c0ea2358fb573a7d918aa7b09fe07f9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc9e54273c0ea2358fb573a7d918aa7b09fe07f9.hip new file mode 100644 index 00000000000..b5fb86f5f56 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc9e54273c0ea2358fb573a7d918aa7b09fe07f9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dcf815ef540060cc7ed43e1c57a28e1d080c5621.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dcf815ef540060cc7ed43e1c57a28e1d080c5621.hip new file mode 100644 index 00000000000..042fc02f207 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dcf815ef540060cc7ed43e1c57a28e1d080c5621.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd10bbf37503bbc92af82bc3487989b41b20ca85.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd10bbf37503bbc92af82bc3487989b41b20ca85.hip new file mode 100644 index 00000000000..e0285cef988 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd10bbf37503bbc92af82bc3487989b41b20ca85.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd11806cd2d3ef1127f676b2d98bf8fff2a1e5ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd11806cd2d3ef1127f676b2d98bf8fff2a1e5ab.hip new file mode 100644 index 00000000000..7c40cc29684 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd11806cd2d3ef1127f676b2d98bf8fff2a1e5ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd35634440edb25cb095800b882c70aaceca1dbb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd35634440edb25cb095800b882c70aaceca1dbb.hip new file mode 100644 index 00000000000..84a28047235 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd35634440edb25cb095800b882c70aaceca1dbb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd67d442001d2b167e70e8730abde4d4461b8569.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd67d442001d2b167e70e8730abde4d4461b8569.hip new file mode 100644 index 00000000000..1db14e6f717 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd67d442001d2b167e70e8730abde4d4461b8569.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd9494d9ac35eba6794a4f9120d2db9932596ef8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd9494d9ac35eba6794a4f9120d2db9932596ef8.hip new file mode 100644 index 00000000000..08a8e86c512 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd9494d9ac35eba6794a4f9120d2db9932596ef8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dda8d021381083bc48b7fb1840729254dd8e5137.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dda8d021381083bc48b7fb1840729254dd8e5137.hip new file mode 100644 index 00000000000..a15cc853279 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dda8d021381083bc48b7fb1840729254dd8e5137.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ddcb1cfea1b0dbe50a02252cba99428fd977527e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ddcb1cfea1b0dbe50a02252cba99428fd977527e.hip new file mode 100644 index 00000000000..206ee07ae74 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ddcb1cfea1b0dbe50a02252cba99428fd977527e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dde93ffe7fca311e136e42fbcd12b05c9fc7174c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dde93ffe7fca311e136e42fbcd12b05c9fc7174c.hip new file mode 100644 index 00000000000..5103e37e5cd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dde93ffe7fca311e136e42fbcd12b05c9fc7174c.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ddf5339054f47d9ed6cc7f9e66ab21ce3bccf3db.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ddf5339054f47d9ed6cc7f9e66ab21ce3bccf3db.hip new file mode 100644 index 00000000000..f580b6bc6ab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ddf5339054f47d9ed6cc7f9e66ab21ce3bccf3db.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template 
<> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de1ff66d2aeb47d2fdccaa4bb6b9d066b380c99e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de1ff66d2aeb47d2fdccaa4bb6b9d066b380c99e.hip new file mode 100644 index 00000000000..0849a2c6380 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de1ff66d2aeb47d2fdccaa4bb6b9d066b380c99e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de26a187c4db06115072a5132e1166b5b03368b0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de26a187c4db06115072a5132e1166b5b03368b0.hip new file mode 100644 index 00000000000..2cbeb6876a0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de26a187c4db06115072a5132e1166b5b03368b0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de36bc309877917a18fd21acb30563c7e2f233c1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de36bc309877917a18fd21acb30563c7e2f233c1.hip new file mode 100644 index 00000000000..c58cb2d1f3a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de36bc309877917a18fd21acb30563c7e2f233c1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de5359f0fba3da9dfed06ddbea8fe2a33a9cf40c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de5359f0fba3da9dfed06ddbea8fe2a33a9cf40c.hip new file mode 100644 index 00000000000..06b3d532e34 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de5359f0fba3da9dfed06ddbea8fe2a33a9cf40c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de6683d175affaa5ff261ab8503f64172d8eba8b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de6683d175affaa5ff261ab8503f64172d8eba8b.hip new file mode 100644 index 00000000000..f17c1d62882 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de6683d175affaa5ff261ab8503f64172d8eba8b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de7eb562a7eff31d589e12945d80233aac202ae2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de7eb562a7eff31d589e12945d80233aac202ae2.hip new file mode 100644 index 00000000000..cba1f9164df --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de7eb562a7eff31d589e12945d80233aac202ae2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de85901d66dc04b1143bb6404445baf65693b781.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de85901d66dc04b1143bb6404445baf65693b781.hip new file mode 100644 index 00000000000..62d93f596d4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de85901d66dc04b1143bb6404445baf65693b781.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_deb9ec2cccab94920e40f62a1f0f094acd919d07.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_deb9ec2cccab94920e40f62a1f0f094acd919d07.hip new file mode 100644 index 00000000000..e53a12d64ba --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_deb9ec2cccab94920e40f62a1f0f094acd919d07.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df0b2bcba57e77d975ec5304fc50cbd09cddf4bb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df0b2bcba57e77d975ec5304fc50cbd09cddf4bb.hip new file mode 100644 index 00000000000..8658271826c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df0b2bcba57e77d975ec5304fc50cbd09cddf4bb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df4bb75ca79f805a81fbad750ad22f6d22b0d8ff.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df4bb75ca79f805a81fbad750ad22f6d22b0d8ff.hip new file mode 100644 index 00000000000..389c1d5340f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df4bb75ca79f805a81fbad750ad22f6d22b0d8ff.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df4c9eb48da49a61957537270d94e56cb4e426be.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df4c9eb48da49a61957537270d94e56cb4e426be.hip new file mode 100644 index 00000000000..b9354ed7dd1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df4c9eb48da49a61957537270d94e56cb4e426be.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df5b1c6758d4b8540158299dd0362297083084c2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df5b1c6758d4b8540158299dd0362297083084c2.hip new file mode 100644 index 00000000000..20337081732 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df5b1c6758d4b8540158299dd0362297083084c2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df645b3888dc8d1df50c47c0d75822eebd3eb019.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df645b3888dc8d1df50c47c0d75822eebd3eb019.hip new file mode 100644 index 00000000000..b44989396c5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df645b3888dc8d1df50c47c0d75822eebd3eb019.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df66feebc9a0dcc508ce002c255154622875e524.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df66feebc9a0dcc508ce002c255154622875e524.hip new file mode 100644 index 00000000000..349d2b5d5c6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df66feebc9a0dcc508ce002c255154622875e524.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dfcd68acfca68d1acac94f493e25be0ef20f209f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dfcd68acfca68d1acac94f493e25be0ef20f209f.hip new file mode 100644 index 00000000000..fc4eaf24627 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dfcd68acfca68d1acac94f493e25be0ef20f209f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e02a198f23c409b715761b702d7b0e6e5992701f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e02a198f23c409b715761b702d7b0e6e5992701f.hip new file mode 100644 index 00000000000..a1a5e869ef3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e02a198f23c409b715761b702d7b0e6e5992701f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); 
+ constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e035773419a9b3631698a3d375d829af55f7731e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e035773419a9b3631698a3d375d829af55f7731e.hip new file mode 100644 index 00000000000..d8a8fb80030 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e035773419a9b3631698a3d375d829af55f7731e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 
k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e088f0f7363804cf5403adef70828ab32d09a02a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e088f0f7363804cf5403adef70828ab32d09a02a.hip new file mode 100644 index 00000000000..34543da0b0f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e088f0f7363804cf5403adef70828ab32d09a02a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e0966fa1ff013e477b1706928de6cb7f8587c154.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e0966fa1ff013e477b1706928de6cb7f8587c154.hip new file mode 100644 index 00000000000..f4d7ae841c5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e0966fa1ff013e477b1706928de6cb7f8587c154.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e09d9baa269dfbb30b714389d1733be51cc419b7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e09d9baa269dfbb30b714389d1733be51cc419b7.hip new file mode 100644 index 00000000000..82265715130 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e09d9baa269dfbb30b714389d1733be51cc419b7.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e0e48d7edfe9513f24ad9fae68cac3aa940b17dd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e0e48d7edfe9513f24ad9fae68cac3aa940b17dd.hip new file mode 100644 index 00000000000..94176d35b88 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e0e48d7edfe9513f24ad9fae68cac3aa940b17dd.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e10f47a44400de385ddbeb99475b717c5646fb41.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e10f47a44400de385ddbeb99475b717c5646fb41.hip new file mode 100644 index 00000000000..a3f727c71b1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e10f47a44400de385ddbeb99475b717c5646fb41.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e11a3b7d4fdfed64e64f7a95dbc64eff541092d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e11a3b7d4fdfed64e64f7a95dbc64eff541092d6.hip new file mode 100644 index 00000000000..d16a0f228f2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e11a3b7d4fdfed64e64f7a95dbc64eff541092d6.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e13b86fe4e153e0bfa8d1e75f3641fe32b0c5149.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e13b86fe4e153e0bfa8d1e75f3641fe32b0c5149.hip new file mode 100644 index 00000000000..c760dc0c539 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e13b86fe4e153e0bfa8d1e75f3641fe32b0c5149.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e16075c3a5fcfe63ba12e854bb1fed6873f014ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e16075c3a5fcfe63ba12e854bb1fed6873f014ab.hip new file mode 100644 index 00000000000..d7034adb329 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e16075c3a5fcfe63ba12e854bb1fed6873f014ab.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e16edb824cecf459a8ec51b8dc74b1e06369aceb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e16edb824cecf459a8ec51b8dc74b1e06369aceb.hip new file mode 100644 index 00000000000..44bf0392983 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e16edb824cecf459a8ec51b8dc74b1e06369aceb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1c1a31a1d8556cbe0b6ea76faacc78855108539.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1c1a31a1d8556cbe0b6ea76faacc78855108539.hip new file mode 100644 index 00000000000..b477771fb5f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1c1a31a1d8556cbe0b6ea76faacc78855108539.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1cc934ba7baab1a2eb062df1e4ee5066e9ffbc3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1cc934ba7baab1a2eb062df1e4ee5066e9ffbc3.hip new file mode 100644 index 00000000000..f0a89505838 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1cc934ba7baab1a2eb062df1e4ee5066e9ffbc3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1d85ad2c9d197f501267fe0804e6985802fbd18.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1d85ad2c9d197f501267fe0804e6985802fbd18.hip new file mode 100644 index 00000000000..8ff56f38102 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1d85ad2c9d197f501267fe0804e6985802fbd18.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2762543d3380185e304f84749a70db1b8d3dd8c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2762543d3380185e304f84749a70db1b8d3dd8c.hip new file mode 100644 index 00000000000..e30f1a1927e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2762543d3380185e304f84749a70db1b8d3dd8c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e28fd64c2f2b27577109a984e6ab82f5f0fcb296.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e28fd64c2f2b27577109a984e6ab82f5f0fcb296.hip new file mode 100644 index 00000000000..c0f84899ac5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e28fd64c2f2b27577109a984e6ab82f5f0fcb296.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2b629c37cf94134693ce455b8c88b72a39df7fe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2b629c37cf94134693ce455b8c88b72a39df7fe.hip new file mode 100644 index 00000000000..ff3d0628f2f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2b629c37cf94134693ce455b8c88b72a39df7fe.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2bf6805a489739abb77c13173d57723e9304afa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2bf6805a489739abb77c13173d57723e9304afa.hip new file mode 100644 index 00000000000..d8baeba584c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2bf6805a489739abb77c13173d57723e9304afa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2c9f955f227430c6224ebc347649386be7f01eb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2c9f955f227430c6224ebc347649386be7f01eb.hip new file mode 100644 index 00000000000..2900490d783 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2c9f955f227430c6224ebc347649386be7f01eb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2deafd2f36cee29109fb824e0135407453adcfe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2deafd2f36cee29109fb824e0135407453adcfe.hip new file mode 100644 index 00000000000..bc725551953 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2deafd2f36cee29109fb824e0135407453adcfe.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e3015c5d50481547aa5754d042d9d7040cf1c7ff.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e3015c5d50481547aa5754d042d9d7040cf1c7ff.hip new file mode 100644 index 00000000000..39c83dc91af --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e3015c5d50481547aa5754d042d9d7040cf1c7ff.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e307a1b0d5a8f94e0a0f4032f401d20b4b643523.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e307a1b0d5a8f94e0a0f4032f401d20b4b643523.hip new file mode 100644 index 00000000000..b13364aa13c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e307a1b0d5a8f94e0a0f4032f401d20b4b643523.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e334e691714f0b99773c2ac515ed82de0f387065.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e334e691714f0b99773c2ac515ed82de0f387065.hip new file mode 100644 index 00000000000..3fa6911d7de --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e334e691714f0b99773c2ac515ed82de0f387065.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e34b7e452a4db74189334697e3a240ad68085f0e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e34b7e452a4db74189334697e3a240ad68085f0e.hip new file mode 100644 index 00000000000..fbeed9e8bab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e34b7e452a4db74189334697e3a240ad68085f0e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e389d0e4442cd8304081892ddc75043e68a6398c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e389d0e4442cd8304081892ddc75043e68a6398c.hip new file mode 100644 index 00000000000..4520e939e63 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e389d0e4442cd8304081892ddc75043e68a6398c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e465193d97d43237c22c04478ca5833011d8dc8b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e465193d97d43237c22c04478ca5833011d8dc8b.hip new file mode 100644 index 00000000000..870b19560a3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e465193d97d43237c22c04478ca5833011d8dc8b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e477abef05ff37ec27705eda51896e2aa3a04966.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e477abef05ff37ec27705eda51896e2aa3a04966.hip new file mode 100644 index 00000000000..e3a7e8eb65e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e477abef05ff37ec27705eda51896e2aa3a04966.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e4d9a2396ceccdadab24602f30e9070901a76dc7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e4d9a2396ceccdadab24602f30e9070901a76dc7.hip new file mode 100644 index 00000000000..1babb48d2ea --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e4d9a2396ceccdadab24602f30e9070901a76dc7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e502730dea6987e2c038446c448aa08bdcc23113.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e502730dea6987e2c038446c448aa08bdcc23113.hip new file mode 100644 index 00000000000..1a309d87657 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e502730dea6987e2c038446c448aa08bdcc23113.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e514c6b4bc75d95a150104a17972abae77cb47ed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e514c6b4bc75d95a150104a17972abae77cb47ed.hip new file mode 100644 index 00000000000..d3da295a237 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e514c6b4bc75d95a150104a17972abae77cb47ed.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e52e3053f30f780f346fa6b7a836ad2554cb85df.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e52e3053f30f780f346fa6b7a836ad2554cb85df.hip new file mode 100644 index 00000000000..112edb20d12 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e52e3053f30f780f346fa6b7a836ad2554cb85df.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
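The forward files above encode their block tile (for example the 128, 128, 32, 256, 32, 256 sequence in the bf16 head-dim-256 variants) as compile-time template parameters that flow into the shape, pipeline and trait types. A small standalone sketch of how such a tile can be carried as constants and used to size a launch grid follows; the struct, the field names and their meanings are my guesses for illustration, not ck_tile's terminology.

// Standalone sketch: a compile-time block-tile description and a grid-size computation.
#include <cstdio>

template <int kM0, int kN0, int kK0, int kN1, int kK1, int kK0L>
struct block_tile {
    static constexpr int m0 = kM0; // rows of the score tile handled per block (guess)
    static constexpr int n0 = kN0; // columns processed per block iteration (guess)
};

template <typename Tile>
void print_grid(int seqlen_q, int batch_times_heads)
{
    // Ceil-divide the query length by the per-block row tile.
    const int grid_x = (seqlen_q + Tile::m0 - 1) / Tile::m0;
    std::printf("grid = (%d, %d) for M0=%d, N0=%d\n",
                grid_x, batch_times_heads, Tile::m0, Tile::n0);
}

int main()
{
    // Mirrors the <128, 128, 32, 256, 32, 256> tile used by the forward files above.
    print_grid<block_tile<128, 128, 32, 256, 32, 256>>(4096, 16);
    return 0;
}

The backward files carry a longer nine-entry tile (for example 16, 128, 128, 16, 128, 16, 32, 128, 128) plus separate warp layouts for the SdP, dKV and dQ GEMMs, as the Gemm0~4 comments in those files note.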
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e56757fb17f5e94a6ba1fb14540a68c36d571159.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e56757fb17f5e94a6ba1fb14540a68c36d571159.hip new file mode 100644 index 00000000000..15d9834ce7a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e56757fb17f5e94a6ba1fb14540a68c36d571159.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e578ec9e09d3b78dca6b5bf0be1538657f02f319.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e578ec9e09d3b78dca6b5bf0be1538657f02f319.hip new file mode 100644 index 00000000000..1f216241217 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e578ec9e09d3b78dca6b5bf0be1538657f02f319.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5935fbda313d3518f142f43d46f56c600f69286.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5935fbda313d3518f142f43d46f56c600f69286.hip new file mode 100644 index 00000000000..3f35cd96001 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5935fbda313d3518f142f43d46f56c600f69286.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5b2bb9f8466de1ad5210e4c39ee7b8ecacdffa9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5b2bb9f8466de1ad5210e4c39ee7b8ecacdffa9.hip new file mode 100644 index 00000000000..fbd7b54057a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5b2bb9f8466de1ad5210e4c39ee7b8ecacdffa9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5b65fc519ea7cfcd19f7eddbc3acad6842ff558.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5b65fc519ea7cfcd19f7eddbc3acad6842ff558.hip new file mode 100644 index 00000000000..155a6049668 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5b65fc519ea7cfcd19f7eddbc3acad6842ff558.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5c5079636a4a31a849ce8a5af89d50330a74628.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5c5079636a4a31a849ce8a5af89d50330a74628.hip new file mode 100644 index 00000000000..4f09e38bc85 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5c5079636a4a31a849ce8a5af89d50330a74628.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5ccd5f7ddc894b2717112cbfc766804e02b7bd1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5ccd5f7ddc894b2717112cbfc766804e02b7bd1.hip new file mode 100644 index 00000000000..46d477547a4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5ccd5f7ddc894b2717112cbfc766804e02b7bd1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e618fb4e529104fc90069c8779ce5463460bd516.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e618fb4e529104fc90069c8779ce5463460bd516.hip new file mode 100644 index 00000000000..32d64eb6f56 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e618fb4e529104fc90069c8779ce5463460bd516.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e638053e01268a4c5883620fc6a9901951e2e01a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e638053e01268a4c5883620fc6a9901951e2e01a.hip new file mode 100644 index 00000000000..fa466e7cc75 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e638053e01268a4c5883620fc6a9901951e2e01a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e639a1e84faa98477b05df71d363b9ff0f9b2760.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e639a1e84faa98477b05df71d363b9ff0f9b2760.hip new file mode 100644 index 00000000000..3c45952d2d8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e639a1e84faa98477b05df71d363b9ff0f9b2760.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e68a9e05debd456a9975953f7b0d510e7a0f6978.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e68a9e05debd456a9975953f7b0d510e7a0f6978.hip new file mode 100644 index 00000000000..5fc8edd83a3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e68a9e05debd456a9975953f7b0d510e7a0f6978.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
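Every generated launcher above wraps its kernel launch in the same architecture guard: the body is compiled only for gfx90a / gfx942 targets, and on any other architecture the launcher returns a 0.0 sentinel instead. Below is a minimal, self-contained sketch of that guard pattern, not part of the generated files; only the macro check mirrors the code above, and the function body is a placeholder rather than the real ck_tile launch.

// Illustrative only: same #if guard as in the generated fmha_* launchers.
// In a host-only compile neither __gfx90a__ nor __gfx942__ is defined,
// so the #else branch is what actually builds here.
#include <cstdio>

float launch_if_supported()
{
#if (defined(__gfx90a__) || defined(__gfx942__))
    // The real generated code constructs kernel arguments and launches the
    // ck_tile kernel here, returning its measured time.
    return 1.0f; // placeholder value
#else
    // Unsupported architecture: the kernel is compiled out and 0.0 is
    // returned, matching the generated launchers above.
    return 0.0f;
#endif
}

int main()
{
    std::printf("%f\n", launch_if_supported());
    return 0;
}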
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6973d75297bd2c3432a7c88e8a9ee1c9ae693bf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6973d75297bd2c3432a7c88e8a9ee1c9ae693bf.hip new file mode 100644 index 00000000000..bcd0166a22d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6973d75297bd2c3432a7c88e8a9ee1c9ae693bf.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6b53fb8d81148ff384d31a703bb4c2e7a5a33af.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6b53fb8d81148ff384d31a703bb4c2e7a5a33af.hip new file mode 100644 index 00000000000..3f007f131c1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6b53fb8d81148ff384d31a703bb4c2e7a5a33af.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6e0ec1db1ea308e226f675e68e29b839e41b252.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6e0ec1db1ea308e226f675e68e29b839e41b252.hip new file mode 100644 index 00000000000..d1df360bbcc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6e0ec1db1ea308e226f675e68e29b839e41b252.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
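Each autogenerated file contributes one explicit specialization of the shared launcher templates (fmha_fwd_, fmha_bwd_dq_dk_dv_, fmha_bwd_convert_dq_, and so on), keyed on a compile-time traits alias such as trait_0 or dq_dk_dv_trait_0, so the runtime dispatcher can select a kernel purely by naming the matching traits type. The sketch below shows that specialization pattern in isolation with hypothetical names; none of these identifiers exist in the PyTorch or ck_tile sources.

// Illustrative only: traits-keyed explicit specialization, the same shape as
// the generated fmha_*_get_name_<...>() definitions above.
#include <iostream>
#include <string>

// Primary template: one declaration shared by all generated translation units.
template <typename Traits>
std::string kernel_name_();

// Hypothetical traits tag standing in for something like dq_dk_dv_trait_0.
struct example_trait_bf16_hd128 {};

// One generated file would supply exactly one such specialization.
template <>
std::string kernel_name_<example_trait_bf16_hd128>()
{
    return "example_bf16_hd128_kernel";
}

int main()
{
    // A dispatcher picks the kernel by instantiating the traits type alone.
    std::cout << kernel_name_<example_trait_bf16_hd128>() << "\n";
    return 0;
}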
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6e6b10e73733716e71ebf5a53703fb935fc5e02.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6e6b10e73733716e71ebf5a53703fb935fc5e02.hip new file mode 100644 index 00000000000..d79b86cd336 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6e6b10e73733716e71ebf5a53703fb935fc5e02.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7153f9a9b0b7c54ddf2debbe297efcffbb4fcfa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7153f9a9b0b7c54ddf2debbe297efcffbb4fcfa.hip new file mode 100644 index 00000000000..c4e6e1b8e8e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7153f9a9b0b7c54ddf2debbe297efcffbb4fcfa.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e73a776ae4ba68c23acab1a5a6381684051738ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e73a776ae4ba68c23acab1a5a6381684051738ab.hip new file mode 100644 index 00000000000..498bc125abf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e73a776ae4ba68c23acab1a5a6381684051738ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e75c757c67aa23cb88e1aced6fcf36b7b28391db.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e75c757c67aa23cb88e1aced6fcf36b7b28391db.hip new file mode 100644 index 00000000000..19ff2f35e99 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e75c757c67aa23cb88e1aced6fcf36b7b28391db.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e75d492ac3a6ab75648056bcf26250a4aa929cfd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e75d492ac3a6ab75648056bcf26250a4aa929cfd.hip new file mode 100644 index 00000000000..bd6e14286bd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e75d492ac3a6ab75648056bcf26250a4aa929cfd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e76879f8ff4796f48ad87ff8003f4f6e6adca9a0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e76879f8ff4796f48ad87ff8003f4f6e6adca9a0.hip new file mode 100644 index 00000000000..26e0c0a3b05 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e76879f8ff4796f48ad87ff8003f4f6e6adca9a0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7ae1294b6dea5c8b93c2b814fa7460c4047105b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7ae1294b6dea5c8b93c2b814fa7460c4047105b.hip new file mode 100644 index 00000000000..e5c376daa8c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7ae1294b6dea5c8b93c2b814fa7460c4047105b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7b2eb64b66d46359fab44333c2c484f4c9dd5de.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7b2eb64b66d46359fab44333c2c484f4c9dd5de.hip new file mode 100644 index 00000000000..bf0cd8eecce --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7b2eb64b66d46359fab44333c2c484f4c9dd5de.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7c0a99e949baa5f3a7ee2d6e84427982f82f76d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7c0a99e949baa5f3a7ee2d6e84427982f82f76d.hip new file mode 100644 index 00000000000..26717f20819 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7c0a99e949baa5f3a7ee2d6e84427982f82f76d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7d37e7ee96c392fa24c02a9143438a3a7d05741.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7d37e7ee96c392fa24c02a9143438a3a7d05741.hip new file mode 100644 index 00000000000..b149181876e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7d37e7ee96c392fa24c02a9143438a3a7d05741.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7de729aa50c10d8101ef504138c3769e3286753.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7de729aa50c10d8101ef504138c3769e3286753.hip new file mode 100644 index 00000000000..891b5d35057 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7de729aa50c10d8101ef504138c3769e3286753.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e83c604d1b8260958becd1c7c209745ff9151715.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e83c604d1b8260958becd1c7c209745ff9151715.hip new file mode 100644 index 00000000000..90b09dd4d9a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e83c604d1b8260958becd1c7c209745ff9151715.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e89bcea4393593313d18a4aa6dcb44cd75bc828d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e89bcea4393593313d18a4aa6dcb44cd75bc828d.hip new file mode 100644 index 00000000000..df1b9f1e97c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e89bcea4393593313d18a4aa6dcb44cd75bc828d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
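The generated files in this patch all repeat one structure: a set of `using` aliases that freeze the tile shape, data type, pipeline, and bias/dropout flags for a single kernel instance, a traits alias such as `dq_dk_dv_trait_0`, and then explicit specializations of `fmha_bwd_dq_dk_dv_`, `fmha_bwd_dq_dk_dv_oneshot_`, and `fmha_bwd_dq_dk_dv_get_name_` (or `fmha_fwd_` for forward kernels), each guarded so the kernel body only compiles for gfx90a/gfx942. The sketch below is not part of the patch and uses simplified placeholder names (`kernel_traits`, `run_kernel_`, `dispatch`) rather than the real ck_tile or PyTorch types; it only illustrates the dispatch-by-traits idiom these files rely on.

```cpp
// Minimal sketch of the trait-dispatch pattern used by the generated files.
// All names here are hypothetical placeholders, not ck_tile/PyTorch APIs.
#include <iostream>

// Each generated file pins one kernel configuration into a traits type...
template <int HeadDim, bool IsCausal>
struct kernel_traits {};

// ...and provides an explicit specialization of the launcher for that traits type.
template <typename Traits>
float run_kernel_(float scale);  // primary template: declared, never defined

template <>
float run_kernel_<kernel_traits<128, true>>(float scale)
{
    // A real specialization builds kargs/grids and launches the GPU kernel;
    // here we only report which instance was selected.
    std::cout << "hd128 causal kernel, scale=" << scale << "\n";
    return 0.0f;
}

template <>
float run_kernel_<kernel_traits<128, false>>(float scale)
{
    std::cout << "hd128 non-causal kernel, scale=" << scale << "\n";
    return 0.0f;
}

// The hand-written API layer maps runtime arguments onto one specialization.
float dispatch(int head_dim, bool is_causal, float scale)
{
    if (head_dim == 128 && is_causal)  return run_kernel_<kernel_traits<128, true>>(scale);
    if (head_dim == 128 && !is_causal) return run_kernel_<kernel_traits<128, false>>(scale);
    return -1.0f;  // unsupported configuration
}

int main()
{
    dispatch(128, true, 0.125f);
    dispatch(128, false, 0.125f);
}
```

Keeping one configuration per translation unit lets the build parallelize across the many generated files, and the per-architecture guards mean that on unsupported targets each specialization falls back to the stubbed `return 0.0;` path visible in the diffs above.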
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8a9427f34bbf5ddb28a39161acc36806e68f2d0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8a9427f34bbf5ddb28a39161acc36806e68f2d0.hip new file mode 100644 index 00000000000..fcb82a0e604 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8a9427f34bbf5ddb28a39161acc36806e68f2d0.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8d8fe5f4f8641998b8b805a20b2ca92d019ee59.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8d8fe5f4f8641998b8b805a20b2ca92d019ee59.hip new file mode 100644 index 00000000000..b576966ede2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8d8fe5f4f8641998b8b805a20b2ca92d019ee59.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8d9b65558398c0c10127b560807578ef117d7ed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8d9b65558398c0c10127b560807578ef117d7ed.hip new file mode 100644 index 00000000000..57f19f7cc94 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8d9b65558398c0c10127b560807578ef117d7ed.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e907e8d1089557dfcc95a05160be5092e9119a53.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e907e8d1089557dfcc95a05160be5092e9119a53.hip new file mode 100644 index 00000000000..3616371a25b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e907e8d1089557dfcc95a05160be5092e9119a53.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e95e3908479965856843317c8b0c42a6961dfd23.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e95e3908479965856843317c8b0c42a6961dfd23.hip new file mode 100644 index 00000000000..a61e1532a18 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e95e3908479965856843317c8b0c42a6961dfd23.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e986d5f8d5591f3e0f1cdfad19c38c420fd93023.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e986d5f8d5591f3e0f1cdfad19c38c420fd93023.hip new file mode 100644 index 00000000000..04c347066e5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e986d5f8d5591f3e0f1cdfad19c38c420fd93023.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e9b04e6d5527ba0b8089ba8bdd264e2d5759338b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e9b04e6d5527ba0b8089ba8bdd264e2d5759338b.hip new file mode 100644 index 00000000000..6d0812a7ab0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e9b04e6d5527ba0b8089ba8bdd264e2d5759338b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e9b53fa68641f45baabf40b7cfb8b35a9a1b9c7f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e9b53fa68641f45baabf40b7cfb8b35a9a1b9c7f.hip new file mode 100644 index 00000000000..b291dd76b35 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e9b53fa68641f45baabf40b7cfb8b35a9a1b9c7f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 
= fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea077e68dbc1bed2dd20a5f4dd35e0cad6330ee4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea077e68dbc1bed2dd20a5f4dd35e0cad6330ee4.hip new file mode 100644 index 00000000000..d74ad5dafa4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea077e68dbc1bed2dd20a5f4dd35e0cad6330ee4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea591185b1c5f521023e250a26f742984255b241.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea591185b1c5f521023e250a26f742984255b241.hip new file mode 100644 index 00000000000..ce0ee132217 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea591185b1c5f521023e250a26f742984255b241.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea62567e9ea16771d8445464c38f5a2931cb355a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea62567e9ea16771d8445464c38f5a2931cb355a.hip new file mode 100644 index 00000000000..47c92f36f55 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea62567e9ea16771d8445464c38f5a2931cb355a.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea6a6d4cc262ea838dbb83ee747112f95fa297bc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea6a6d4cc262ea838dbb83ee747112f95fa297bc.hip new file mode 100644 index 00000000000..fab6d688bcc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea6a6d4cc262ea838dbb83ee747112f95fa297bc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eab6cdc59bf216f7045f0cf5f221bb91ec415cd2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eab6cdc59bf216f7045f0cf5f221bb91ec415cd2.hip new file mode 100644 index 00000000000..5dc190b447d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eab6cdc59bf216f7045f0cf5f221bb91ec415cd2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eac353f963c52624cf79e82cc2b2c02eed94b677.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eac353f963c52624cf79e82cc2b2c02eed94b677.hip new file mode 100644 index 00000000000..cea6f89c5d1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eac353f963c52624cf79e82cc2b2c02eed94b677.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eac5952f46f4f2bf06257b00661774eeed48a323.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eac5952f46f4f2bf06257b00661774eeed48a323.hip new file mode 100644 index 00000000000..f634a21a052 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eac5952f46f4f2bf06257b00661774eeed48a323.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eb278488b2cca114adca5e4614d86f92447f937a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eb278488b2cca114adca5e4614d86f92447f937a.hip new file mode 100644 index 00000000000..9830d036ca8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eb278488b2cca114adca5e4614d86f92447f937a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ebb241b947a0adfc8e50c5d71765c14af24593ae.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ebb241b947a0adfc8e50c5d71765c14af24593ae.hip new file mode 100644 index 00000000000..88135712368 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ebb241b947a0adfc8e50c5d71765c14af24593ae.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ebb9abf5b09e63cbe76390bb46ff7cbefb3141f0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ebb9abf5b09e63cbe76390bb46ff7cbefb3141f0.hip new file mode 100644 index 00000000000..e8543aa3640 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ebb9abf5b09e63cbe76390bb46ff7cbefb3141f0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec171210efd217c07d357fcf42e5372ad7e9abab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec171210efd217c07d357fcf42e5372ad7e9abab.hip new file mode 100644 index 00000000000..bfdd7adfdec --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec171210efd217c07d357fcf42e5372ad7e9abab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec3deb1382003ac010d9bc1c59d1878d3ec7a727.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec3deb1382003ac010d9bc1c59d1878d3ec7a727.hip new file mode 100644 index 00000000000..61f7d97e707 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec3deb1382003ac010d9bc1c59d1878d3ec7a727.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec51d24ab5f24e003ed6751ae8ae5b327892b15a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec51d24ab5f24e003ed6751ae8ae5b327892b15a.hip new file mode 100644 index 00000000000..38dee8f172e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec51d24ab5f24e003ed6751ae8ae5b327892b15a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec7ec8d547ee9713aa3b5b667f22cdcaa8f62b2d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec7ec8d547ee9713aa3b5b667f22cdcaa8f62b2d.hip new file mode 100644 index 00000000000..9e79206b7bd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec7ec8d547ee9713aa3b5b667f22cdcaa8f62b2d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec7fc24902b1ebd8f2bf8088b0ecf6de8be8362d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec7fc24902b1ebd8f2bf8088b0ecf6de8be8362d.hip new file mode 100644 index 00000000000..8962a6918bc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec7fc24902b1ebd8f2bf8088b0ecf6de8be8362d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec9f63a538940e5ace02ae5b5ddc01f730adac4d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec9f63a538940e5ace02ae5b5ddc01f730adac4d.hip new file mode 100644 index 00000000000..16e339d9356 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec9f63a538940e5ace02ae5b5ddc01f730adac4d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eca613eaa8471ad7da66d2f8f2b8e07f6e02b467.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eca613eaa8471ad7da66d2f8f2b8e07f6e02b467.hip new file mode 100644 index 00000000000..0d558baa021 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eca613eaa8471ad7da66d2f8f2b8e07f6e02b467.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ecd7dec90b3c62bf3a30bd75d3c6869529a06b01.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ecd7dec90b3c62bf3a30bd75d3c6869529a06b01.hip new file mode 100644 index 00000000000..b347eae247d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ecd7dec90b3c62bf3a30bd75d3c6869529a06b01.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ece60111633db08f765b3c7cd5cd768cbd030255.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ece60111633db08f765b3c7cd5cd768cbd030255.hip new file mode 100644 index 00000000000..cc28ddb979d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ece60111633db08f765b3c7cd5cd768cbd030255.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ed37ba962e0288e2840eb0925d016b5a7e3b3164.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ed37ba962e0288e2840eb0925d016b5a7e3b3164.hip new file mode 100644 index 00000000000..c7af660b732 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ed37ba962e0288e2840eb0925d016b5a7e3b3164.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ed6bdf67720e938d538a867548ac3579b8238169.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ed6bdf67720e938d538a867548ac3579b8238169.hip new file mode 100644 index 00000000000..7665e4bb3ad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ed6bdf67720e938d538a867548ac3579b8238169.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} 
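Each generated forward file above follows the same shape: it fixes one fmha_fwd_traits_ combination (head dim, dtype, pipeline, bias and padding flags), instantiates the matching ck_tile kernel, and provides an explicit specialization of fmha_fwd_ whose launch path is compiled only for gfx90a/gfx942. The sketch below is illustrative only and is not code from this patch; names such as fwd_trait, launch_fwd and the returned timing value are hypothetical stand-ins. It shows the explicit-specialization dispatch idea these translation units rely on.

// Illustrative sketch, not part of the patch: one translation unit per trait
// combination, each providing an explicit specialization of a shared launcher
// template that a runtime dispatcher (not shown here) selects from.
#include <iostream>

struct stream_config { int log_level_ = 0; };   // stand-in for ck_tile::stream_config
struct fwd_args {};                             // stand-in for fmha_fwd_args

template <int HDim, typename DataType, bool kIsCausal>
struct fwd_trait {};                            // hypothetical compile-time key

template <typename Trait>
float launch_fwd(const stream_config& s, fwd_args a);  // declared once, defined per generated file

// What a single generated file contributes: exactly one specialization.
template <>
float launch_fwd<fwd_trait<32, float, false>>(const stream_config& s, fwd_args /*a*/)
{
    if (s.log_level_ > 0)
        std::cout << ", hdim32_fp32_noncausal_kernel" << std::flush;
#if defined(__gfx90a__) || defined(__gfx942__)
    return 0.1f;   // the real files build kargs/grids and launch, returning measured time
#else
    return 0.0f;   // other targets compile to a no-op, as in the files above
#endif
}

int main()
{
    stream_config s{1};
    std::cout << "elapsed: " << launch_fwd<fwd_trait<32, float, false>>(s, fwd_args{}) << "\n";
}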
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ede81dbc4cb208ef6e684c76ba1eb451d37fe10c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ede81dbc4cb208ef6e684c76ba1eb451d37fe10c.hip new file mode 100644 index 00000000000..cc154d8bbcf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ede81dbc4cb208ef6e684c76ba1eb451d37fe10c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee1a43f2210a8d1e5623411c95c33424cee5e747.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee1a43f2210a8d1e5623411c95c33424cee5e747.hip new file mode 100644 index 00000000000..9af65064c5a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee1a43f2210a8d1e5623411c95c33424cee5e747.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee239db5a67c23a383590a651f0d8a0be43a13c7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee239db5a67c23a383590a651f0d8a0be43a13c7.hip new file mode 100644 index 00000000000..f5c3ef50b15 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee239db5a67c23a383590a651f0d8a0be43a13c7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
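Each backward file pairs two entry points for the same kernel type: a timed specialization (fmha_bwd_dq_dk_dv_) that goes through ck_tile::launch_kernel and returns the elapsed time, and a _oneshot_ specialization that constructs the kernel functor and invokes it directly on the stream without timing, alongside a _get_name_ accessor; all of them are guarded to gfx90a/gfx942. The sketch below is a hedged illustration of that timed-versus-oneshot split; Kernel, run_timed and run_oneshot are hypothetical stand-ins, not the ck_tile API.

// Illustrative sketch only: a timed launcher (useful for benchmarking or
// tuning) next to a fire-and-forget launcher that skips measurement.
#include <chrono>
#include <iostream>
#include <string>

struct Kernel {                                  // stand-in for a generated kernel type
    static std::string GetName() { return "bwd_dq_dk_dv_hdim32_bf16"; }
    void operator()() const { /* device work would happen here */ }
};

template <typename K>
float run_timed(const K& k)                      // analogue of fmha_bwd_dq_dk_dv_<trait>
{
    auto t0 = std::chrono::steady_clock::now();
    k();                                         // launch and wait
    auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<float, std::milli>(t1 - t0).count();
}

template <typename K>
void run_oneshot(const K& k)                     // analogue of ..._oneshot_<trait>
{
    k();                                         // launch without measuring
}

int main()
{
    Kernel k;
    std::cout << k.GetName() << " took " << run_timed(k) << " ms\n";
    run_oneshot(k);
}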
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee8e709eec7aef1fa681053c6d2969a5ff18c45c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee8e709eec7aef1fa681053c6d2969a5ff18c45c.hip new file mode 100644 index 00000000000..70a4e712867 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee8e709eec7aef1fa681053c6d2969a5ff18c45c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee974931e65d6b16b7c868d462b95dcae20b7513.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee974931e65d6b16b7c868d462b95dcae20b7513.hip new file mode 100644 index 00000000000..81793791340 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee974931e65d6b16b7c868d462b95dcae20b7513.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eeb0e96b759e18cf703cfab0cda1385726f6e0a1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eeb0e96b759e18cf703cfab0cda1385726f6e0a1.hip new file mode 100644 index 00000000000..0d405aba7ec --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eeb0e96b759e18cf703cfab0cda1385726f6e0a1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eee408cf9456ff977aa7d12345e9b2f1e60639f1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eee408cf9456ff977aa7d12345e9b2f1e60639f1.hip new file mode 100644 index 00000000000..9346ac1d32c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eee408cf9456ff977aa7d12345e9b2f1e60639f1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef2ebb4a86e7ed0001de9c5e607b66fe8877409f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef2ebb4a86e7ed0001de9c5e607b66fe8877409f.hip new file mode 100644 index 00000000000..92fbcde9d8e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef2ebb4a86e7ed0001de9c5e607b66fe8877409f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef40f0acf1885096efb840ec5600ec421c4db331.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef40f0acf1885096efb840ec5600ec421c4db331.hip new file mode 100644 index 00000000000..1f7e407864d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef40f0acf1885096efb840ec5600ec421c4db331.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef5421703cbfa63a58ec02701e245d479a1fbfc1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef5421703cbfa63a58ec02701e245d479a1fbfc1.hip new file mode 100644 index 00000000000..9bcf6ad9ff7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef5421703cbfa63a58ec02701e245d479a1fbfc1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef7cc2aa1ffd38298b52764a93cd1271b4d92f8d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef7cc2aa1ffd38298b52764a93cd1271b4d92f8d.hip new file mode 100644 index 00000000000..c202ed8b7a9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef7cc2aa1ffd38298b52764a93cd1271b4d92f8d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efaa0cb33c71cb8ca7b83dd0e7a6c7b01f6b50a9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efaa0cb33c71cb8ca7b83dd0e7a6c7b01f6b50a9.hip new file mode 100644 index 00000000000..b0dedc50e79 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efaa0cb33c71cb8ca7b83dd0e7a6c7b01f6b50a9.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efb9e7d9af47cdf79f15f674f8976c05f08b0ce8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efb9e7d9af47cdf79f15f674f8976c05f08b0ce8.hip new file 
mode 100644 index 00000000000..edc00c4b123 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efb9e7d9af47cdf79f15f674f8976c05f08b0ce8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = 
fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efc6a7b25710f0626c3af534111b161e1459d2e1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efc6a7b25710f0626c3af534111b161e1459d2e1.hip new file mode 100644 index 00000000000..2749b00dc64 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efc6a7b25710f0626c3af534111b161e1459d2e1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f01468c62c878295443981662e037ec5213cf7a3.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f01468c62c878295443981662e037ec5213cf7a3.hip new file mode 100644 index 00000000000..13f81636bf3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f01468c62c878295443981662e037ec5213cf7a3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + 
+template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f020134822739be6fa0bb3d98e9dec79f025324a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f020134822739be6fa0bb3d98e9dec79f025324a.hip new file mode 100644 index 00000000000..25854794b08 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f020134822739be6fa0bb3d98e9dec79f025324a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f0209426a8e6bfeef7d8ae7b16db791888142298.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f0209426a8e6bfeef7d8ae7b16db791888142298.hip new file mode 100644 index 00000000000..26df8389117 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f0209426a8e6bfeef7d8ae7b16db791888142298.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f028af9e5e3c25800dde938e991aaab4fc1d64aa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f028af9e5e3c25800dde938e991aaab4fc1d64aa.hip new file mode 100644 index 00000000000..4d37fd26387 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f028af9e5e3c25800dde938e991aaab4fc1d64aa.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f053c9c32518b895daaa3521827f37af78836fb8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f053c9c32518b895daaa3521827f37af78836fb8.hip new file mode 100644 index 00000000000..03107c712a5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f053c9c32518b895daaa3521827f37af78836fb8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f069b38b26c30bc770f74c856e47eb498f5818e7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f069b38b26c30bc770f74c856e47eb498f5818e7.hip new file mode 100644 index 00000000000..c0e69ef42e7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f069b38b26c30bc770f74c856e47eb498f5818e7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f0cad48d9bc80d58705ea60eb2dda4baad68cedb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f0cad48d9bc80d58705ea60eb2dda4baad68cedb.hip new file mode 100644 index 00000000000..228cd503785 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f0cad48d9bc80d58705ea60eb2dda4baad68cedb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f1246d1013d954a9316f4432c986d3be9459c548.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f1246d1013d954a9316f4432c986d3be9459c548.hip new file mode 100644 index 00000000000..bf43a6fbfb4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f1246d1013d954a9316f4432c986d3be9459c548.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f12f1f1b679cabab04218037ef370d2c7e1fe332.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f12f1f1b679cabab04218037ef370d2c7e1fe332.hip new file mode 100644 index 00000000000..9800f89bbfd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f12f1f1b679cabab04218037ef370d2c7e1fe332.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 
k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f15c41ddb04ec7f80235bb3db19198dd6b699713.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f15c41ddb04ec7f80235bb3db19198dd6b699713.hip new file mode 100644 index 00000000000..3be783819ea --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f15c41ddb04ec7f80235bb3db19198dd6b699713.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f18c74becc24a93427d9c0838784e9b6caad6e81.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f18c74becc24a93427d9c0838784e9b6caad6e81.hip new file mode 100644 index 00000000000..23f0e69bd4d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f18c74becc24a93427d9c0838784e9b6caad6e81.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr 
ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f1ecc90ad7b86791a9e6f73a582aeff30f393804.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f1ecc90ad7b86791a9e6f73a582aeff30f393804.hip new file mode 100644 index 00000000000..8daa2a9efd9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f1ecc90ad7b86791a9e6f73a582aeff30f393804.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename 
FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f21596e8c608a795ff971aea8e199db9e72b65d7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f21596e8c608a795ff971aea8e199db9e72b65d7.hip new file mode 100644 index 00000000000..f4e72e89537 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f21596e8c608a795ff971aea8e199db9e72b65d7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
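Each generated .hip file in this series follows the same shape: it fixes one tile/dtype/feature configuration as a trait alias (trait_0 or dq_dk_dv_trait_0), instantiates the matching ck_tile kernel type for that configuration, and then defines the explicit specialization of the launcher template for exactly that trait, so the hand-written FMHA entry points can dispatch to the right translation unit purely by trait type. The following is a minimal, self-contained sketch of that specialization-per-configuration dispatch pattern; traits_, launch_ and dispatch_fwd are illustrative names for this sketch only, not real ck_tile or PyTorch identifiers.

// Sketch only: mirrors the trait-keyed dispatch used by the generated files,
// with purely illustrative names (not ck_tile / PyTorch APIs).
#include <cstdio>
#include <stdexcept>

// One tag type per generated configuration (stands in for fmha_fwd_traits_<...>).
template <int HDim, bool IsFp16>
struct traits_ {};

// Primary launcher template: never defined generically; each "generated file"
// supplies exactly one explicit specialization.
template <typename Trait>
float launch_(float scale);

// "File 1": hdim 64, fp16 configuration.
template <>
float launch_<traits_<64, true>>(float scale)
{
    std::printf("launching hdim64/fp16 kernel, scale=%f\n", scale);
    return 0.1f; // pretend elapsed milliseconds
}

// "File 2": hdim 128, bf16 configuration.
template <>
float launch_<traits_<128, false>>(float scale)
{
    std::printf("launching hdim128/bf16 kernel, scale=%f\n", scale);
    return 0.2f;
}

// Runtime dispatcher: maps runtime arguments onto a compile-time trait,
// the role played by the non-generated fmha_fwd()/fmha_bwd() entry points.
float dispatch_fwd(int hdim, bool is_fp16, float scale)
{
    if (hdim == 64 && is_fp16)   return launch_<traits_<64, true>>(scale);
    if (hdim == 128 && !is_fp16) return launch_<traits_<128, false>>(scale);
    throw std::runtime_error("unsupported head dim / dtype combination");
}

int main()
{
    dispatch_fwd(64, true, 0.125f);
    dispatch_fwd(128, false, 0.125f);
    return 0;
}

The per-configuration files never reference each other; only the dispatcher needs to know the full set of supported (head-dim, dtype, feature) combinations, which is what lets generate.py emit them as independent translation units.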
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24bd5b92ce6bba640b8ec6b4e53fe35902c5572.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24bd5b92ce6bba640b8ec6b4e53fe35902c5572.hip new file mode 100644 index 00000000000..b1bd3ff5c76 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24bd5b92ce6bba640b8ec6b4e53fe35902c5572.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24d42e820adc1a26a428d59df7ffdd7f8580176.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24d42e820adc1a26a428d59df7ffdd7f8580176.hip new file mode 100644 index 00000000000..a51d4c5c23c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24d42e820adc1a26a428d59df7ffdd7f8580176.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24f26e45d5cf567d29fbe375fbf8abdec39186f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24f26e45d5cf567d29fbe375fbf8abdec39186f.hip new file mode 100644 index 00000000000..2f2c89a1188 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24f26e45d5cf567d29fbe375fbf8abdec39186f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f25b87c435bc5d7d85d738f3fdf68947d79f5a77.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f25b87c435bc5d7d85d738f3fdf68947d79f5a77.hip new file mode 100644 index 00000000000..b1946d6c24a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f25b87c435bc5d7d85d738f3fdf68947d79f5a77.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f280e1639680ac1e5830a21f921bfe2cf364ef42.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f280e1639680ac1e5830a21f921bfe2cf364ef42.hip new file mode 100644 index 00000000000..13993a552fa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f280e1639680ac1e5830a21f921bfe2cf364ef42.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f2da112b1e07c44fc8a7f19368da203f6935049c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f2da112b1e07c44fc8a7f19368da203f6935049c.hip new file mode 100644 index 00000000000..fb2720ff39c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f2da112b1e07c44fc8a7f19368da203f6935049c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f30316cfe49323638f71ba688dd8ff9b2266b335.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f30316cfe49323638f71ba688dd8ff9b2266b335.hip new file mode 100644 index 00000000000..58ae3c39fcf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f30316cfe49323638f71ba688dd8ff9b2266b335.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3193ea266f3718398bc5622f8bc7042c3527a42.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3193ea266f3718398bc5622f8bc7042c3527a42.hip new file mode 100644 index 00000000000..6010154f029 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3193ea266f3718398bc5622f8bc7042c3527a42.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f34fdb8294257d951dcc9c4fa7ecf1192568b91b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f34fdb8294257d951dcc9c4fa7ecf1192568b91b.hip new file mode 100644 index 00000000000..393ae7e3f12 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f34fdb8294257d951dcc9c4fa7ecf1192568b91b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f36aaa63ed42a578b953ebd614318d44cf44e8a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f36aaa63ed42a578b953ebd614318d44cf44e8a3.hip new file mode 100644 index 00000000000..9ac0f714315 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f36aaa63ed42a578b953ebd614318d44cf44e8a3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f395bec57c3b2e6e169134dd8d20b287d7405134.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f395bec57c3b2e6e169134dd8d20b287d7405134.hip new file mode 100644 index 00000000000..2ab46ef081b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f395bec57c3b2e6e169134dd8d20b287d7405134.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3bf7ef503bb026258b3ec3d82d3ef1443046964.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3bf7ef503bb026258b3ec3d82d3ef1443046964.hip new file mode 100644 index 00000000000..39e85e7ddad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3bf7ef503bb026258b3ec3d82d3ef1443046964.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3d0166931e4406873d8f552a5d5b61fde2391a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3d0166931e4406873d8f552a5d5b61fde2391a3.hip new file mode 100644 index 00000000000..76fe2c03186 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3d0166931e4406873d8f552a5d5b61fde2391a3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3fd08d56f8a9be1a8dd104cdb1ac58e283b5064.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3fd08d56f8a9be1a8dd104cdb1ac58e283b5064.hip new file mode 100644 index 00000000000..c9a585a170c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3fd08d56f8a9be1a8dd104cdb1ac58e283b5064.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3ff73f82aee3184849d04c2364eaa45c6d0de9c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3ff73f82aee3184849d04c2364eaa45c6d0de9c.hip new file mode 100644 index 00000000000..3addb4329e6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3ff73f82aee3184849d04c2364eaa45c6d0de9c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f42cf0e5fe479690883507028748b0cd3dc83cbb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f42cf0e5fe479690883507028748b0cd3dc83cbb.hip new file mode 100644 index 00000000000..bd0bdf1993b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f42cf0e5fe479690883507028748b0cd3dc83cbb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4658c32d562f9d60c5ca1262a2e0df2375063bb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4658c32d562f9d60c5ca1262a2e0df2375063bb.hip new file mode 100644 index 00000000000..af4ca0cdd9d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4658c32d562f9d60c5ca1262a2e0df2375063bb.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f48f8b681a405bfeba5aadaef40f32367ec5cd2b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f48f8b681a405bfeba5aadaef40f32367ec5cd2b.hip new file mode 100644 index 00000000000..275608fa511 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f48f8b681a405bfeba5aadaef40f32367ec5cd2b.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4900c0a5c0d03dc17d7a907ab40652d9920e756.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4900c0a5c0d03dc17d7a907ab40652d9920e756.hip new file mode 100644 index 00000000000..0c63b898e7d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4900c0a5c0d03dc17d7a907ab40652d9920e756.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4a6438394dd3427f29aa0bbe58ad1f797c3c38d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4a6438394dd3427f29aa0bbe58ad1f797c3c38d.hip new file mode 100644 index 00000000000..26ce1457026 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4a6438394dd3427f29aa0bbe58ad1f797c3c38d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4b87f983a5e84582efa1663f84da76cf60b5f6f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4b87f983a5e84582efa1663f84da76cf60b5f6f.hip new file mode 100644 index 00000000000..dace41abd3a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4b87f983a5e84582efa1663f84da76cf60b5f6f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4c803838f5644ccc6f04f7c8a6233fed0b6639e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4c803838f5644ccc6f04f7c8a6233fed0b6639e.hip new file mode 100644 index 00000000000..a784b433b88 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4c803838f5644ccc6f04f7c8a6233fed0b6639e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4df1cbfbaf67705820f125b474469ad7ebab0c0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4df1cbfbaf67705820f125b474469ad7ebab0c0.hip new file mode 100644 index 00000000000..a024563da30 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4df1cbfbaf67705820f125b474469ad7ebab0c0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f50fa4ea674a590d0a817367ad9915a5fce20c51.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f50fa4ea674a590d0a817367ad9915a5fce20c51.hip new file mode 100644 index 00000000000..637d3ef0cc0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f50fa4ea674a590d0a817367ad9915a5fce20c51.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f51f1a11f778d99a00aa5959a3e58a41fcbfb1e3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f51f1a11f778d99a00aa5959a3e58a41fcbfb1e3.hip new file mode 100644 index 00000000000..a7219ad0dbe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f51f1a11f778d99a00aa5959a3e58a41fcbfb1e3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f525b59df454ccf53da6cb201e0aa8d09f52a2ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f525b59df454ccf53da6cb201e0aa8d09f52a2ad.hip new file mode 100644 index 00000000000..fc02d2c2a9f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f525b59df454ccf53da6cb201e0aa8d09f52a2ad.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f57f84892e2a8496169b7406e63b0d4f5aa63aaf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f57f84892e2a8496169b7406e63b0d4f5aa63aaf.hip new file mode 100644 index 00000000000..d9d8e7ec6c5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f57f84892e2a8496169b7406e63b0d4f5aa63aaf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f5803aadd93e33567aa6b23100ce4fbb6c040dd6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f5803aadd93e33567aa6b23100ce4fbb6c040dd6.hip new file mode 100644 index 00000000000..5a14831ce68 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f5803aadd93e33567aa6b23100ce4fbb6c040dd6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f5f1797f6b672a55476348571ce17645c8a62869.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f5f1797f6b672a55476348571ce17645c8a62869.hip new file mode 100644 index 00000000000..6a4c999990e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f5f1797f6b672a55476348571ce17645c8a62869.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6566441ac3074578cfe45758ba0583c0da0a5ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6566441ac3074578cfe45758ba0583c0da0a5ab.hip new file mode 100644 index 00000000000..7c731eb83cd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6566441ac3074578cfe45758ba0583c0da0a5ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
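Note on the pattern above: each autogenerated .hip file fixes one compile-time configuration (tile shape, dtype, pipeline, mask, bias, dropout and padding flags) of the ck_tile FMHA backward kernel and provides the explicit specializations of fmha_bwd_dq_dk_dv_<trait> (plus the oneshot and convert-dQ variants) for that configuration, with the kernel launch guarded to gfx90a/gfx942. A host-side dispatcher can then select the matching specialization from runtime arguments. Below is a minimal sketch of that one-specialization-per-file dispatch pattern; every name in it (bwd_traits, fmha_bwd_dispatch, the argument fields) is a hypothetical placeholder, not the actual CK or PyTorch API.

// Sketch only: illustrates the specialization-and-dispatch idea with made-up names.
#include <iostream>

struct stream_config { int log_level_ = 0; };
struct fmha_bwd_args { int hdim = 128; bool is_fp16 = true; };

// Trait tag encoding the compile-time configuration of one instantiation.
template <int HDim, bool IsFp16>
struct bwd_traits {};

// Primary template is only declared; each "generated" translation unit
// supplies exactly one explicit specialization.
template <typename Traits>
float fmha_bwd_dq_dk_dv_(const stream_config& s, fmha_bwd_args a);

// One such specialization (in the real tree it lives in its own .hip file
// and launches a ck_tile kernel instead of printing).
template <>
float fmha_bwd_dq_dk_dv_<bwd_traits<128, true>>(const stream_config& s, fmha_bwd_args /*a*/)
{
    if (s.log_level_ > 0)
        std::cout << "fmha_bwd hdim=128 fp16\n";
    return 0.0f;
}

// Host-side dispatch: map runtime properties onto the trait tag.
inline float fmha_bwd_dispatch(const stream_config& s, fmha_bwd_args a)
{
    if (a.hdim <= 128 && a.is_fp16)
        return fmha_bwd_dq_dk_dv_<bwd_traits<128, true>>(s, a);
    return -1.0f; // configuration not covered by this sketch
}

Keeping one specialization per generated file presumably keeps each translation unit small and lets the build compile only the configurations that were actually generated.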
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f672bf80a78885428b2c02e522426470653a7351.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f672bf80a78885428b2c02e522426470653a7351.hip new file mode 100644 index 00000000000..f217112dc33 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f672bf80a78885428b2c02e522426470653a7351.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f682399cd6412fed6a1141296a7e4d42078f7b29.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f682399cd6412fed6a1141296a7e4d42078f7b29.hip new file mode 100644 index 00000000000..199d3a13f17 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f682399cd6412fed6a1141296a7e4d42078f7b29.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6856ca950bcf173571766c3f04de4163be0402e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6856ca950bcf173571766c3f04de4163be0402e.hip new file mode 100644 index 00000000000..2e1c1aa05a5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6856ca950bcf173571766c3f04de4163be0402e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f69548d6cced86c21c09c6475237a0cb926df0ed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f69548d6cced86c21c09c6475237a0cb926df0ed.hip new file mode 100644 index 00000000000..3d3f5a67f99 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f69548d6cced86c21c09c6475237a0cb926df0ed.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f69878f4ca8cfe6b8d8748766f66a1ef8eab20ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f69878f4ca8cfe6b8d8748766f66a1ef8eab20ad.hip new file mode 100644 index 00000000000..60737953fb1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f69878f4ca8cfe6b8d8748766f66a1ef8eab20ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
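// ---------------------------------------------------------------------------
// Editorial note, not part of the autogenerated file: the launch bodies of the
// fmha_bwd_dq_dk_dv_ specializations further down in this file (and in every
// sibling file) sit inside #if (defined(__gfx90a__) || defined(__gfx942__)),
// so the expensive kernel instantiation is compiled only in the device passes
// for the supported gfx archs (gfx90a being the MI200 family and gfx942 the
// MI300 family); host and other device passes fall through to the 0.0
// placeholder. A minimal, stand-alone illustration of the same guard pattern,
// with hypothetical names:
namespace editorial_arch_guard_sketch {

inline float launch_or_skip()
{
#if (defined(__gfx90a__) || defined(__gfx942__))
    // Supported device pass: this is where a real file instantiates and
    // launches its kernel via ck_tile::make_kernel / ck_tile::launch_kernel.
    return 1.0f;
#else
    // Host pass or unsupported arch: keep the symbol, do no work.
    return 0.0f;
#endif
}

} // namespace editorial_arch_guard_sketch
// ---------------------------------------------------------------------------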
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6f102a388ffb05c690a20a29cfe0b35a35eed61.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6f102a388ffb05c690a20a29cfe0b35a35eed61.hip new file mode 100644 index 00000000000..42ee427842e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6f102a388ffb05c690a20a29cfe0b35a35eed61.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7035f4bfd8f2f427720a07e3c311bccc1dba683.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7035f4bfd8f2f427720a07e3c311bccc1dba683.hip new file mode 100644 index 00000000000..14cf83ee896 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7035f4bfd8f2f427720a07e3c311bccc1dba683.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f71f96ce4dcc7f789a8ace73c230c203b05ff6dc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f71f96ce4dcc7f789a8ace73c230c203b05ff6dc.hip new file mode 100644 index 00000000000..002dbccee2e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f71f96ce4dcc7f789a8ace73c230c203b05ff6dc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f727911254904ce4341e4ff5f8bafc430b8cfbbf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f727911254904ce4341e4ff5f8bafc430b8cfbbf.hip new file mode 100644 index 00000000000..8e09dab3963 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f727911254904ce4341e4ff5f8bafc430b8cfbbf.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f731289837f915e2aec1bd01eef1b3c1b099864d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f731289837f915e2aec1bd01eef1b3c1b099864d.hip new file mode 100644 index 00000000000..65acf8e8803 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f731289837f915e2aec1bd01eef1b3c1b099864d.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f79def2b4edf6d18f6ef1d6b141f9e0435441f6a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f79def2b4edf6d18f6ef1d6b141f9e0435441f6a.hip new file mode 100644 index 00000000000..65836ac6f42 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f79def2b4edf6d18f6ef1d6b141f9e0435441f6a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7aa9c39b06e55bf4bc9f9a2a0fb075c9d4e69ce.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7aa9c39b06e55bf4bc9f9a2a0fb075c9d4e69ce.hip new file mode 100644 index 00000000000..40ba9885c6d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7aa9c39b06e55bf4bc9f9a2a0fb075c9d4e69ce.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7cf08242b3fb1c643d4149bec985b667b9d28fa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7cf08242b3fb1c643d4149bec985b667b9d28fa.hip new file mode 100644 index 00000000000..c0d7dcf105d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7cf08242b3fb1c643d4149bec985b667b9d28fa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f851da732f397624717160f89271514bc334b59b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f851da732f397624717160f89271514bc334b59b.hip new file mode 100644 index 00000000000..616877eba63 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f851da732f397624717160f89271514bc334b59b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, 
blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f861d8693f82d22e2c5b1abbcbae5f30f4433e5e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f861d8693f82d22e2c5b1abbcbae5f30f4433e5e.hip new file mode 100644 index 00000000000..98a398a2dad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f861d8693f82d22e2c5b1abbcbae5f30f4433e5e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) 
|| defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f87790f260630f312b84888dcbdf849ce130ae59.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f87790f260630f312b84888dcbdf849ce130ae59.hip new file mode 100644 index 00000000000..056ee5b30a0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f87790f260630f312b84888dcbdf849ce130ae59.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = 
+ ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f87991cb7787a29d3ce4711b4ce04c5fb6a14ca9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f87991cb7787a29d3ce4711b4ce04c5fb6a14ca9.hip new file mode 100644 index 00000000000..39324fd6e09 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f87991cb7787a29d3ce4711b4ce04c5fb6a14ca9.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f90410c26d7649e21e2ae5e32e7af89d84d2ea70.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f90410c26d7649e21e2ae5e32e7af89d84d2ea70.hip new file mode 100644 index 00000000000..9b832ba6b1a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f90410c26d7649e21e2ae5e32e7af89d84d2ea70.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f92e9a82c879051d6fe3c42108f8a574187704af.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f92e9a82c879051d6fe3c42108f8a574187704af.hip new file mode 100644 index 00000000000..6f02a47ef1d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f92e9a82c879051d6fe3c42108f8a574187704af.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f93bc23b8a4f1e0fc5c5756c4e1c835bf59dea09.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f93bc23b8a4f1e0fc5c5756c4e1c835bf59dea09.hip new file mode 100644 index 00000000000..912d254bd41 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f93bc23b8a4f1e0fc5c5756c4e1c835bf59dea09.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f93bf815b520a9d9e17b43bf9d7fb870751b6225.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f93bf815b520a9d9e17b43bf9d7fb870751b6225.hip new file mode 100644 index 00000000000..631f6d6710c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f93bf815b520a9d9e17b43bf9d7fb870751b6225.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f974b12e83e214c30995a25631d37df1478927af.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f974b12e83e214c30995a25631d37df1478927af.hip new file mode 100644 index 00000000000..cd39fec637f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f974b12e83e214c30995a25631d37df1478927af.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f9824fb32933b27501ae8a7f43f460a2dda6a814.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f9824fb32933b27501ae8a7f43f460a2dda6a814.hip new file mode 100644 index 00000000000..47b1b0016c4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f9824fb32933b27501ae8a7f43f460a2dda6a814.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f98a6b193fec3203eaa75819f6b51aa45a48f212.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f98a6b193fec3203eaa75819f6b51aa45a48f212.hip new file mode 100644 index 00000000000..ef354a336f0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f98a6b193fec3203eaa75819f6b51aa45a48f212.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f9c58761c927b222112cb5cb6c9acb5d3c915785.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f9c58761c927b222112cb5cb6c9acb5d3c915785.hip new file mode 100644 index 00000000000..164fb143b34 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f9c58761c927b222112cb5cb6c9acb5d3c915785.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa16fa84278b489af253b52839786f94aeeac36f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa16fa84278b489af253b52839786f94aeeac36f.hip new file mode 100644 index 00000000000..4576735604a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa16fa84278b489af253b52839786f94aeeac36f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa62a97675719c2e8e9bb97361b92ff1c7b9d2ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa62a97675719c2e8e9bb97361b92ff1c7b9d2ef.hip new file mode 100644 index 00000000000..96f3ca0a42a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa62a97675719c2e8e9bb97361b92ff1c7b9d2ef.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa85f869a92f0482605e52019828244b12e12b44.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa85f869a92f0482605e52019828244b12e12b44.hip new file mode 100644 index 00000000000..d4ca3b055d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa85f869a92f0482605e52019828244b12e12b44.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fabdc143c29d5ca50ab1e96a814bda6d05b0d5d2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fabdc143c29d5ca50ab1e96a814bda6d05b0d5d2.hip new file mode 100644 index 00000000000..d0124f68af1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fabdc143c29d5ca50ab1e96a814bda6d05b0d5d2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fac5a0f98b94530befd634891e42c424bb86f0e1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fac5a0f98b94530befd634891e42c424bb86f0e1.hip new file mode 100644 index 00000000000..46659f7fb59 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fac5a0f98b94530befd634891e42c424bb86f0e1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fac99c3c82b77946f6844699d2333cd532a78a26.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fac99c3c82b77946f6844699d2333cd532a78a26.hip new file mode 100644 index 00000000000..5f26cbb6af9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fac99c3c82b77946f6844699d2333cd532a78a26.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_faf56e45b2240515e97fc1bfd552eb03b6de5094.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_faf56e45b2240515e97fc1bfd552eb03b6de5094.hip new file mode 100644 index 00000000000..ff068ebf08d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_faf56e45b2240515e97fc1bfd552eb03b6de5094.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
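Every generated translation unit in this patch follows the same shape: a block of `using` aliases pins one tile/pipeline/trait configuration, and a specialization of the matching launcher (`fmha_fwd_`, `fmha_bwd_dq_dk_dv_`, `fmha_bwd_dot_do_o_`, or their `_oneshot_` / `_get_name_` variants) builds the kernel arguments, guards on the supported gfx architectures, and launches. Below is a condensed sketch of one forward launcher, reusing the alias names from the hunks above; the explicit `<trait_0>` specialization argument and the template parameters passed to `ck_tile::make_kernel` are inferred for illustration and are not copied from the generated source.

// Condensed sketch of a generated forward launcher (illustrative, not part of the patch).
template <>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
    using k_ = fmha_kernel_0;                                   // kernel type fixed by this file's aliases
    if (s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;       // optional kernel-name logging
    auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a);   // host args -> kernel args + launch grid
    constexpr dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
    // Device passes for gfx90a/gfx942 launch the kernel and return the measured time.
    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
    return 0.0;                                                 // other compilation passes compile to a no-op
#endif
}

The backward (`fmha_bwd_dq_dk_dv_`) files additionally specialize an `_oneshot_` launcher that runs without timing and a `_get_name_` helper used for logging; only the alias block at the top differs from file to file.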
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_faf686067fa433cea5e95dd523846dc881eff635.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_faf686067fa433cea5e95dd523846dc881eff635.hip new file mode 100644 index 00000000000..aa184595a30 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_faf686067fa433cea5e95dd523846dc881eff635.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb2fbb135d59028afcf867c2cf08edc323565528.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb2fbb135d59028afcf867c2cf08edc323565528.hip new file mode 100644 index 00000000000..668aa06ab5e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb2fbb135d59028afcf867c2cf08edc323565528.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb4c15452f9155c5966990f09432e5eb7e28e785.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb4c15452f9155c5966990f09432e5eb7e28e785.hip new file mode 100644 index 00000000000..8e97e9ed3ae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb4c15452f9155c5966990f09432e5eb7e28e785.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb4c5f8fecfbbe16e6648becb3b5ca89fa3d8a94.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb4c5f8fecfbbe16e6648becb3b5ca89fa3d8a94.hip new file mode 100644 index 00000000000..e80141ab174 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb4c5f8fecfbbe16e6648becb3b5ca89fa3d8a94.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb5bb49928ce5515d7b297d5eadd4ec70a22d60b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb5bb49928ce5515d7b297d5eadd4ec70a22d60b.hip new file mode 100644 index 00000000000..8e0f6d9cd46 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb5bb49928ce5515d7b297d5eadd4ec70a22d60b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb79e1f9231692d736dbada062ed6821f34927bf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb79e1f9231692d736dbada062ed6821f34927bf.hip new file mode 100644 index 00000000000..48413b3f03c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb79e1f9231692d736dbada062ed6821f34927bf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
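Every launch body in these generated launchers is wrapped in #if (defined(__gfx90a__) || defined(__gfx942__)): the CK kernels are only built for gfx90a (MI200-series) and gfx942 (MI300-series) targets, and on any other target the entry point degenerates to a stub that returns 0.0. The snippet below is a small hypothetical HIP translation unit showing the same guard-and-stub idea; scale_kernel and launch_scale are invented names, and the hipEvent timing is a simplified stand-in for the stream_config-based measurement used by ck_tile::launch_kernel.

// Hypothetical illustration of the gfx-arch guard used above: real work is compiled
// only for gfx90a/gfx942 device targets, everything else becomes an empty stub.
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void scale_kernel(float* data, float factor, int n)
{
#if (defined(__gfx90a__) || defined(__gfx942__))
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        data[i] *= factor;          // real work on MI200/MI300-class GPUs
#endif                              // other targets: empty kernel body
}

// Host wrapper mirroring the generated entry points: launch the kernel and return
// a timing in milliseconds (or garbage-free zero work when the body was stubbed).
float launch_scale(float* d_data, float factor, int n)
{
    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);
    hipEventRecord(start, nullptr);
    scale_kernel<<<(n + 255) / 256, 256, 0, nullptr>>>(d_data, factor, n);
    hipEventRecord(stop, nullptr);
    hipEventSynchronize(stop);
    float ms = 0.0f;
    hipEventElapsedTime(&ms, start, stop);
    hipEventDestroy(start);
    hipEventDestroy(stop);
    return ms;
}

int main()
{
    const int n = 1024;
    float* d = nullptr;
    hipMalloc(reinterpret_cast<void**>(&d), n * sizeof(float));
    hipMemset(d, 0, n * sizeof(float));
    std::printf("kernel time: %.3f ms\n", launch_scale(d, 2.0f, n));
    hipFree(d);
    return 0;
}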
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb9477a613665cebcad781389ba7c5a36f51efe2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb9477a613665cebcad781389ba7c5a36f51efe2.hip new file mode 100644 index 00000000000..54776cf48f8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb9477a613665cebcad781389ba7c5a36f51efe2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fba36678d5047ded97ee7a7ba9feb9569afdb6ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fba36678d5047ded97ee7a7ba9feb9569afdb6ea.hip new file mode 100644 index 00000000000..630108f41c5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fba36678d5047ded97ee7a7ba9feb9569afdb6ea.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fba47fa8d9b5375bc408af68b67345ab9dba2eb8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fba47fa8d9b5375bc408af68b67345ab9dba2eb8.hip new file mode 100644 index 00000000000..ac6c3079732 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fba47fa8d9b5375bc408af68b67345ab9dba2eb8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fbea85b766bf0c918ee0baf24dffc6a5563d5105.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fbea85b766bf0c918ee0baf24dffc6a5563d5105.hip new file mode 100644 index 00000000000..3ba3f9fe627 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fbea85b766bf0c918ee0baf24dffc6a5563d5105.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fbeec221cd63adaedceec39db41ea942f99f5133.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fbeec221cd63adaedceec39db41ea942f99f5133.hip new file mode 100644 index 00000000000..c18420b352f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fbeec221cd63adaedceec39db41ea942f99f5133.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc030b61ae20c4b7d9b2d10930a17e01e9e93328.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc030b61ae20c4b7d9b2d10930a17e01e9e93328.hip new file mode 100644 index 00000000000..4a3c485deb2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc030b61ae20c4b7d9b2d10930a17e01e9e93328.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc1790325b59bd44b0a5f6cf9723a25fd845cba7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc1790325b59bd44b0a5f6cf9723a25fd845cba7.hip new file mode 100644 index 00000000000..815bcf0af22 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc1790325b59bd44b0a5f6cf9723a25fd845cba7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc1eb85a00017efdc610e4259d2abe935b85304f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc1eb85a00017efdc610e4259d2abe935b85304f.hip new file mode 100644 index 00000000000..64139de1e3a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc1eb85a00017efdc610e4259d2abe935b85304f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc5841a729099340d608e31023acbeaeade3e886.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc5841a729099340d608e31023acbeaeade3e886.hip new file mode 100644 index 00000000000..b0b857d26e6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc5841a729099340d608e31023acbeaeade3e886.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc5ebf0f2200f37ccc0849e0c3745f6e2f00111d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc5ebf0f2200f37ccc0849e0c3745f6e2f00111d.hip new file mode 100644 index 00000000000..43b274eecde --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc5ebf0f2200f37ccc0849e0c3745f6e2f00111d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc7b0916744b593435d8e1e7b6d874d760cd5e3b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc7b0916744b593435d8e1e7b6d874d760cd5e3b.hip new file mode 100644 index 00000000000..a6820e08447 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc7b0916744b593435d8e1e7b6d874d760cd5e3b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc86c13e933cba40553ffba31d53aad27415ce4b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc86c13e933cba40553ffba31d53aad27415ce4b.hip new file mode 100644 index 00000000000..7b8d8bdeb39 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc86c13e933cba40553ffba31d53aad27415ce4b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcb0b08e29b2e1bf181fceceb9dc416e54f52b00.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcb0b08e29b2e1bf181fceceb9dc416e54f52b00.hip new file mode 100644 index 00000000000..fd63719fa02 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcb0b08e29b2e1bf181fceceb9dc416e54f52b00.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcb6ef39c3db49f26f736d6c9221dd825409ec4e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcb6ef39c3db49f26f736d6c9221dd825409ec4e.hip new file mode 100644 index 00000000000..444b07067e6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcb6ef39c3db49f26f736d6c9221dd825409ec4e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcbe827108d252b2f5847fa8e132c9c3e56a90a0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcbe827108d252b2f5847fa8e132c9c3e56a90a0.hip new file mode 100644 index 00000000000..cfa4f3dab02 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcbe827108d252b2f5847fa8e132c9c3e56a90a0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fccabea88b8e290688c1b360875d228e6fdf1624.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fccabea88b8e290688c1b360875d228e6fdf1624.hip new file mode 100644 index 00000000000..72bd02eae43 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fccabea88b8e290688c1b360875d228e6fdf1624.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) 
+ return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd10a3b937e9659716925e39a01d794914b08e26.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd10a3b937e9659716925e39a01d794914b08e26.hip new file mode 100644 index 00000000000..b2b4fc033c4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd10a3b937e9659716925e39a01d794914b08e26.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd19d7614f2ed5da21a52ed172ef62cc07c9c01a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd19d7614f2ed5da21a52ed172ef62cc07c9c01a.hip new file mode 100644 index 00000000000..f01e1872b5a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd19d7614f2ed5da21a52ed172ef62cc07c9c01a.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd26e43ca652e6f58ff48c356165aa4349833b55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd26e43ca652e6f58ff48c356165aa4349833b55.hip new file mode 100644 index 00000000000..08035d2ca66 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd26e43ca652e6f58ff48c356165aa4349833b55.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd345632e0cae0d549ba79626a08b1885711deb6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd345632e0cae0d549ba79626a08b1885711deb6.hip new file mode 100644 index 00000000000..64baf279f42 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd345632e0cae0d549ba79626a08b1885711deb6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd3558b4c7a667dbc365c4c2ceda646975408f51.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd3558b4c7a667dbc365c4c2ceda646975408f51.hip new file mode 100644 index 00000000000..0e1cf08e17a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd3558b4c7a667dbc365c4c2ceda646975408f51.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd614df484b263deae3b3c20adb0ce7b62eaa651.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd614df484b263deae3b3c20adb0ce7b62eaa651.hip new file mode 100644 index 00000000000..877c173a42e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd614df484b263deae3b3c20adb0ce7b62eaa651.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd9cd1305633b62b68fb8474ce021f639f8492e7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd9cd1305633b62b68fb8474ce021f639f8492e7.hip new file mode 100644 index 00000000000..3001308cb7f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd9cd1305633b62b68fb8474ce021f639f8492e7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fde12cd366d6850ce26afce98e5076b695b4875b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fde12cd366d6850ce26afce98e5076b695b4875b.hip new file mode 100644 index 00000000000..1896571bff3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fde12cd366d6850ce26afce98e5076b695b4875b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe245e9ea974adce2b9807d33b9ba12d916eaffb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe245e9ea974adce2b9807d33b9ba12d916eaffb.hip new file mode 100644 index 00000000000..4bad8e1b5a5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe245e9ea974adce2b9807d33b9ba12d916eaffb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe72cdd69944d2d765478d4aed13066a02b76f6d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe72cdd69944d2d765478d4aed13066a02b76f6d.hip new file mode 100644 index 00000000000..e6856b66f61 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe72cdd69944d2d765478d4aed13066a02b76f6d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe8b8c3525fe86a20a2d6c69585f3e36c16caabd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe8b8c3525fe86a20a2d6c69585f3e36c16caabd.hip new file mode 100644 index 00000000000..6957e7bff39 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe8b8c3525fe86a20a2d6c69585f3e36c16caabd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe97b7adcd67ed9bda8831d1f3f1ca7590c6d251.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe97b7adcd67ed9bda8831d1f3f1ca7590c6d251.hip new file mode 100644 index 00000000000..bc042bf645c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe97b7adcd67ed9bda8831d1f3f1ca7590c6d251.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe9d98dbec5096a89b116f85675af772f023014a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe9d98dbec5096a89b116f85675af772f023014a.hip new file mode 100644 index 00000000000..bbc4cd7e95f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe9d98dbec5096a89b116f85675af772f023014a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
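Note (not part of the generated diff): every autogenerated translation unit above follows the same pattern. It pins one concrete kernel configuration (data type, head dimension, block/warp tile shape, mask, bias, dropout, and padding flags) in a trait alias such as dq_dk_dv_trait_0, instantiates the matching ck_tile kernel type, and then provides explicit specializations of fmha_bwd_dq_dk_dv_<Trait>, fmha_bwd_dq_dk_dv_oneshot_<Trait>, and fmha_bwd_dq_dk_dv_get_name_<Trait> whose launches are guarded to gfx90a/gfx942. The sketch below is illustrative only: it assumes a generated "fmha_bwd.hpp" header (its exact path is not shown in this diff) declares fmha_bwd_args and the primary templates specialized here, and launch_bwd_for_config is a hypothetical helper name, not an API added by this PR.

    // Minimal sketch of calling one generated specialization, under the
    // assumptions stated above.
    #include <iostream>      // for the optional kernel-name logging below
    #include "fmha_bwd.hpp"  // assumed generated header declaring the primary templates

    template <typename Trait>
    float launch_bwd_for_config(const ck_tile::stream_config& stream,
                                fmha_bwd_args args)
    {
        // Mirrors the logging done inside the generated specializations.
        if (stream.log_level_ > 0)
            std::cout << "selected kernel: "
                      << fmha_bwd_dq_dk_dv_get_name_<Trait>() << std::endl;

        // Launches the ck_tile kernel instantiated in the matching
        // autogenerated .hip file; the specialization itself returns 0.0
        // when compiled for an unsupported gfx architecture.
        return fmha_bwd_dq_dk_dv_<Trait>(stream, args);
    }

In the PR itself, the selection among these per-file trait tags is presumably performed by an autogenerated dispatcher (as is typical for ck_tile's generate.py output), which matches the runtime arguments against each configuration rather than exposing a helper like the one sketched here.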
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_feb5e77111fe1e20bafdb83a925b5faeeb6214af.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_feb5e77111fe1e20bafdb83a925b5faeeb6214af.hip new file mode 100644 index 00000000000..619437013f9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_feb5e77111fe1e20bafdb83a925b5faeeb6214af.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fecd7501265b4c4dcf015485e63e2324304f70d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fecd7501265b4c4dcf015485e63e2324304f70d3.hip new file mode 100644 index 00000000000..521da69e520 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fecd7501265b4c4dcf015485e63e2324304f70d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fecffa403b3631b1957e1a9a06f18fdb3b4eee5f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fecffa403b3631b1957e1a9a06f18fdb3b4eee5f.hip new file mode 100644 index 00000000000..aa3b11549de --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fecffa403b3631b1957e1a9a06f18fdb3b4eee5f.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ff453e3bdc9752cb7b81f7cc3056325a8b9a8ad4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ff453e3bdc9752cb7b81f7cc3056325a8b9a8ad4.hip new file mode 100644 index 00000000000..cf7620a56dc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ff453e3bdc9752cb7b81f7cc3056325a8b9a8ad4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ff6862dbdbb20bc63a650e1f93e9ac169bb702b2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ff6862dbdbb20bc63a650e1f93e9ac169bb702b2.hip new file mode 100644 index 00000000000..26ea0d1d035 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ff6862dbdbb20bc63a650e1f93e9ac169bb702b2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffb5b7349a671b182d73c8016590f26fe06a4cba.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffb5b7349a671b182d73c8016590f26fe06a4cba.hip new file mode 100644 index 00000000000..6e646b9ee3d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffb5b7349a671b182d73c8016590f26fe06a4cba.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffb8adef0cef91a86f36872407fea35df90e8f2b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffb8adef0cef91a86f36872407fea35df90e8f2b.hip new file mode 100644 index 00000000000..2b475a9ecdb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffb8adef0cef91a86f36872407fea35df90e8f2b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffc6056d9fe125a4dbe08c1d86354e51f7daadd5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffc6056d9fe125a4dbe08c1d86354e51f7daadd5.hip new file mode 100644 index 00000000000..3f713f229af --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffc6056d9fe125a4dbe08c1d86354e51f7daadd5.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffd868d49abdb769ab82c21508d655daf54b8a99.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffd868d49abdb769ab82c21508d655daf54b8a99.hip new file mode 100644 index 00000000000..6c055ce460d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffd868d49abdb769ab82c21508d655daf54b8a99.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fff7aa57cca501f221077124359a589b3a6f9d0a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fff7aa57cca501f221077124359a589b3a6f9d0a.hip new file mode 100644 index 00000000000..29c2a526ee1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fff7aa57cca501f221077124359a589b3a6f9d0a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fffbfcac254e33926131a71905e93f9cc0aef89e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fffbfcac254e33926131a71905e93f9cc0aef89e.hip new file mode 100644 index 00000000000..a54a4399b8b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fffbfcac254e33926131a71905e93f9cc0aef89e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
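+// Illustrative note (not part of the generated output): each of these autogenerated
+// translation units follows the same pattern -- it fixes one backward tile configuration,
+// defines a dq_dk_dv_trait_ alias that appears to be used purely for pattern matching by
+// the dispatch header (mirroring the fmha_fwd_traits_ comment in fmha_fwd.hpp), and then
+// specializes fmha_bwd_dq_dk_dv_, fmha_bwd_dq_dk_dv_oneshot_ and fmha_bwd_dq_dk_dv_get_name_
+// for that trait. The launch body is compiled only for gfx90a/gfx942; on other targets the
+// timed variant simply returns 0.0.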
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd.hpp new file mode 100644 index 00000000000..90eb5c20869 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd.hpp @@ -0,0 +1,773 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
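+// Rough map of this header (summary only, not exhaustive):
+//   * FmhaFwdTypeConfig specializations select the CK tile element types per dtype
+//     (fp16 / bf16 / fp8 / bf8) together with the float accumulator and LSE types.
+//   * fmha_fwd_args, fmha_fwd_splitkv_args and fmha_fwd_appendkv_args carry the runtime
+//     pointers, sizes and strides supplied by the mha_*_ck.hip wrappers.
+//   * The *_create_kargs_and_grids helpers turn those args into kernel arguments plus a
+//     launch grid, selecting the batch- or group-mode overload at compile time.
+//   * The fmha_fwd*_traits_ templates exist only to pattern-match generated kernel
+//     instantiations, while fmha_fwd / fmha_fwd_splitkv / fmha_fwd_appendkv at the bottom
+//     are the public entry points.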
+ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +template +struct FmhaFwdTypeConfig; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::half_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::fp8_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bf8_t; + using KDataType = ck_tile::bf8_t; + using VDataType = ck_tile::bf8_t; + using BiasDataType = ck_tile::bf8_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf8_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +// runtime args, some will passed to karg, some will used to compute grids/blocks +struct fmha_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* rand_val_ptr; + void* lse_ptr; + void* o_ptr; + + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* + seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t 
hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float scale_p; + float scale_o; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + + float p_drop; + bool s_randval; + + std::variant, std::pair> + drop_seed_offset; +}; + +struct fmha_fwd_splitkv_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* lse_acc_ptr; + void* o_acc_ptr; + void* lse_ptr; + void* o_ptr; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + + const void* cache_batch_idx; + + // the real seqlen_q & seqlen_k are decided by following: + // batch mode: seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k + // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // kvcache mode (use same kernel as batch mode): + // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + ck_tile::index_t num_splits; + + float scale_s; + float scale_p; + float scale_o; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_o_acc; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; + ck_tile::index_t batch_stride_o; + ck_tile::index_t split_stride_lse_acc; + ck_tile::index_t split_stride_o_acc; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; +}; + +struct fmha_fwd_appendkv_args +{ + void* q_ptr; + void* k_ptr; + const void* knew_ptr; + void* v_ptr; + const void* vnew_ptr; + + const void* 
seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_knew; + ck_tile::index_t batch; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0 + const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0 + ck_tile::index_t rotary_dim; + bool has_mask; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + + const void* cache_batch_idx; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_knew; + ck_tile::index_t stride_v; + ck_tile::index_t stride_vnew; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_knew; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_vnew; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_knew; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_vnew; +}; + +template +auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_acc_ptr, + 
args.o_acc_ptr, + args.batch, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_k, // only used for paged-kvcache + args.batch_stride_v, // only used for paged-kvcache + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.seqlen_q, + args.seqlen_k, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + }(); + + dim3 grids = + Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits); + + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel argumentszs + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargsImpl(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.seqstart_q_ptr, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + args.split_stride_lse_acc, + args.split_stride_o_acc); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargsImpl(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.seqlen_q, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.batch_stride_lse, + args.batch_stride_o, + args.split_stride_lse_acc, + args.split_stride_o_acc); + } + }(); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = Kernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.knew_ptr, + args.v_ptr, + args.vnew_ptr, + args.seqlen_q, + args.seqlen_k_ptr, + args.seqlen_knew, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q 
/ args.nhead_k, + args.rotary_cos_ptr, + args.rotary_sin_ptr, + args.rotary_dim, + args.has_mask, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, + args.stride_q, + args.stride_k, + args.stride_knew, + args.stride_v, + args.stride_vnew, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_knew, + args.nhead_stride_v, + args.nhead_stride_vnew, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_knew, + args.batch_stride_v, + args.batch_stride_vnew); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew); + + return ck_tile::make_tuple(kargs, grids); +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_fwd_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kHasDropout = kHasDropout_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); + +template +struct fmha_fwd_splitkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kIsPagedKV = kIsPagedKV_; +}; + +template +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); + +template +std::string fmha_fwd_splitkv_get_name_(); + +template +struct fmha_fwd_splitkv_combine_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; 
+ static constexpr bool kPadDv = kPadDv_; +}; + +template +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); + +template +std::string fmha_fwd_splitkv_combine_get_name_(); + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_fwd_appendkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_; + static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_; + static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_; + static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSk = kPadSk_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr auto RotaryEnum = RotaryEnum_; + static constexpr bool kIsPagedKV = kIsPagedKV_; +}; + +template +float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args); + +// This is the public API, will be generated by script +struct fmha_fwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_lse; + bool has_dropout; + bool do_fp8_static_quant; + // TODO: padding check is inside this api +}; +float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); + +struct fmha_fwd_splitkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_lse; + bool do_fp8_static_quant; + // TODO: padding check is inside this api +}; +float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, + fmha_fwd_splitkv_args, + const ck_tile::stream_config&); + +struct fmha_fwd_appendkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_v_rowmajor; + rope_enum rope_type; +}; +float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, + fmha_fwd_appendkv_args, + const ck_tile::stream_config&); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mask.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mask.hpp new file mode 100644 index 00000000000..133049057d7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mask.hpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
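+// Quick reference for the mask strings accepted by mask_info::decode below (illustrative):
+//   "0"                -> no mask
+//   "t" / "b"          -> causal mask, top-left or bottom-right aligned
+//   "t:l,r" / "b:l,r"  -> FA-style sliding window with left/right sizes l and r
+//   "g:y,x"            -> generic window given directly as y/x coordinates
+//   "xt:w" / "xb:w"    -> xformers-style window of total size w
+// Worked example, assuming seqlen_q = 4 and seqlen_k = 8:
+//   mask_info m = mask_info::decode("b", 4, 8);
+//   // -> m.type == mask_enum::mask_bottom_right, m.y == 4, m.x == 8 - 4 + 1 == 5,
+//   //    m.left == -1, m.right == 0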
+ +#pragma once + +#include +#include + +#include +#include + +// keep this in sync with ck_tile::GenericAttentionMaskEnum +enum class mask_enum +{ + no_mask = 0, + mask_top_left, + mask_bottom_right, + window_generic, +}; + +struct mask_info +{ + mask_enum type; + ck_tile::index_t y, x; + ck_tile::index_t left, right; // FA style SWA left/right + + void serialize(std::ostream& os) const + { + if(type == mask_enum::no_mask) + os << "n"; + else if(type == mask_enum::mask_top_left) + os << "t(" << left << ":" << right << ")"; + else if(type == mask_enum::mask_bottom_right) + os << "b(" << left << ":" << right << ")"; + else + { + os << "g(" << y << ":" << x << ")"; + } + } + static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k) + { + ck_tile::index_t x_total = seqlen_k; + ck_tile::index_t y_total = seqlen_q; + mask_info tmp; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + if(t == "xt" || t == "xb") + { + // xformer style sliding window attn from top-left + ck_tile::index_t window_size = atoi(v.c_str()); + ck_tile::index_t left_size = -1; + ck_tile::index_t right_size = 0; + if(window_size > 0) + { + left_size = window_size / 2; + right_size = window_size - 1 - left_size; + } + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, y_total, x_total, t == "xt"); + + tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right; + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = left_size; + tmp.right = right_size; + } + else + { + auto found_1 = v.find(","); + if(found_1 == std::string::npos) + { + printf("not supported value %s, %s\n", v.c_str(), str.c_str()); + assert(0); + } + tmp.type = mask_enum::window_generic; + ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); + ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + // TODO: some validation + if(t == "t") + { + tmp.type = mask_enum::mask_top_left; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, true); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "b") + { + tmp.type = mask_enum::mask_bottom_right; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, false); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "g") + { + tmp.y = v0; + tmp.x = v1; + tmp.left = v0; // TODO: don't use this? 
+ tmp.right = v1; + } + else + { + printf("not supported type %s, %s\n", t.c_str(), str.c_str()); + assert(0); + } + } + } + else + { + auto set_causal_top_left = [&]() { + tmp.type = mask_enum::mask_top_left; + tmp.y = seqlen_q; + tmp.x = 1; + tmp.left = -1; + tmp.right = 0; + }; + auto set_causal_bottom_right = [&]() { + tmp.type = mask_enum::mask_bottom_right; + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + tmp.left = -1; + tmp.right = 0; + }; + if(str == "t") + set_causal_top_left(); + else if(str == "b") + set_causal_bottom_right(); + else + { + tmp.type = static_cast(atoi(str.c_str())); + if(tmp.type == mask_enum::mask_top_left) + { + set_causal_top_left(); + } + else if(tmp.type == mask_enum::mask_bottom_right) + { + set_causal_bottom_right(); + } + } + } + return tmp; + } + + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) + { + mi.serialize(os); + return os; + } +}; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip new file mode 100644 index 00000000000..7dbb4e1cb56 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip @@ -0,0 +1,407 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include +#include +#include + +namespace pytorch_flash { + +fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool enable_alibi, + bool deterministic) +{ + return fmha_bwd_traits{head_size, + head_size, + dtype, + false, // is_group_mode + mask.type, + enable_alibi ? bias_enum::alibi : bias_enum::no_bias, + false, // has_dbias + has_dropout, + false, // s_randval + deterministic}; +} + +fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, + // sizes + const int b, + const int seqlen_q, + const int seqlen_k, + const int h, + const int h_k, + const int hdim, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + c10::optional &alibi_slopes_, + const at::Tensor out, + const at::Tensor softmax_lse, + const at::Tensor dout, + at::Tensor dq_acc, + at::Tensor d, + at::Tensor dq, + at::Tensor dk, + at::Tensor dv, + float softmax_scale, + float p_dropout, + std::pair drop_seed_offset) +{ + // q: (batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t batch_stride_q = q.stride(0); + ck_tile::index_t stride_q = q.stride(1); + ck_tile::index_t nhead_stride_q = q.stride(2); + + // k: (batch_size, seqlen_k, nheads_k, hdim) + ck_tile::index_t batch_stride_k = k.stride(0); + ck_tile::index_t stride_k = k.stride(1); + ck_tile::index_t nhead_stride_k = k.stride(2); + + // v: (batch_size, seqlen_k, nheads_k, hdim) + ck_tile::index_t batch_stride_v = v.stride(0); + ck_tile::index_t stride_v = v.stride(1); + ck_tile::index_t nhead_stride_v = v.stride(2); + + // o: (batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t batch_stride_o = out.stride(0); + ck_tile::index_t stride_o = out.stride(1); + ck_tile::index_t nhead_stride_o = out.stride(2); + + // lse: (batch_size, nheads, seqlen_q) + ck_tile::index_t batch_stride_lse = softmax_lse.stride(0); + ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1); + + // do: (batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t batch_stride_do = dout.stride(0); + ck_tile::index_t stride_do = dout.stride(1); + ck_tile::index_t nhead_stride_do = dout.stride(2); 
+ + // d: (batch_size, nheads, seqlen_q) + // CK assume d share the same stride with lse + + // dq: (batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t batch_stride_dq = dq.stride(0); + ck_tile::index_t stride_dq = dq.stride(1); + ck_tile::index_t nhead_stride_dq = dq.stride(2); + + // dk_expanded: (batch_size, seqlen_k, nheads, hdim) + ck_tile::index_t batch_stride_dk = dk.stride(0); + ck_tile::index_t stride_dk = dk.stride(1); + ck_tile::index_t nhead_stride_dk = dk.stride(2); + + // dv_expanded: (batch_size, seqlen_k, nheads, hdim) + ck_tile::index_t batch_stride_dv = dv.stride(0); + ck_tile::index_t stride_dv = dv.stride(1); + ck_tile::index_t nhead_stride_dv = dv.stride(2); + + // dq_acc: (split, batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0); + ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1); + ck_tile::index_t stride_dq_acc = dq_acc.stride(2); + ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3); + + float p_undrop = 1.0 - p_dropout; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + // alibi_slopes:(batch_size, nheads) or (nhead) + stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; + } + + return fmha_bwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + out.data_ptr(), + softmax_lse.data_ptr(), + dout.data_ptr(), + d.data_ptr(), + nullptr, // rand_val + dq.data_ptr(), + dk.data_ptr(), + dv.data_ptr(), + nullptr, // dbias + dq_acc.data_ptr(), // dq_acc + nullptr, // seqstart_q + nullptr, // seqstart_k + nullptr, // seqlen_k_ptr + seqlen_q, + seqlen_k, + b, + seqlen_q, // max_seqlen_q + seqlen_k, // max_seqlen_k + hdim, // hdim_q + hdim, // hdim_v + h, // nhead + h_k, // nhead_k + softmax_scale, + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_o, + 0, // stride_randval + stride_do, + stride_dq_acc, + stride_dq, + stride_dk, + stride_dv, + 0, // stride_dbias, FA without bias + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_o, + 0, // nhead_stride_randval + nhead_stride_do, + nhead_stride_lse, + nhead_stride_dq_acc, + nhead_stride_dq, + nhead_stride_dk, + nhead_stride_dv, + 0, // nhead_stride_dbias, FA without dbias + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0 , // batch_stride_bias, FA without bias + batch_stride_o, + 0, // batch_stride_randval + batch_stride_do, + batch_stride_lse, + batch_stride_dq_acc, + batch_stride_dq, + batch_stride_dk, + batch_stride_dv, + 0 , // batch_stride_dbias, FA without dbias + split_stride_dq_acc, + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + p_undrop, + drop_seed_offset}; +} + +std::tuple +mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + 
c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) +{ +#ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); +#endif + if (is_causal) { window_size_right = 0; } + + bool is_dropout = p_dropout > 0.0; + auto stream = at::cuda::getCurrentHIPStream().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + + std::string q_dtype_str = q_dtype == at::kHalf ? "fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + const int seqlen_q = sizes[1]; + const int num_heads = sizes[2]; + const int head_size_og = dout.size(3); // unpadded hdim + const int head_size_8x = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8"); + TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8"); + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + mask_info mask; + if (is_causal) { + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // casual + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
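+    // Roughly: the causal and local branches both build a "b:<left>,<right>" string that
+    // mask_info::decode (mask.hpp) maps to a bottom-right-aligned window for
+    // (seqlen_q, seqlen_k); -1 acts as "no limit" on a side, which is why the window sizes
+    // are reset to -1 above once they reach seqlen_k.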
+ std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local + } + + // q, k, v, out had been padded in mha_fwd + // dq_, dk_, dv_ are also padded tensor + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_8x); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_8x); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_8x); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_8x); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size_8x); + } else { + dq = at::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size_8x); + } else { + dk = at::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_8x); + } else { + dv = at::empty_like(v); + } + + at::Tensor dout_padded; + if (head_size_og % 8 != 0) { + dout_padded = at::pad(dout, {0, 8 - head_size_og % 8}); + } else { + dout_padded = dout; + } + + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + + if (!deterministic) { + dq_accum = at::zeros({1, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + } else { + const ck_tile::index_t kN0 = head_size_8x <= 128 ? 
128 : 64; + const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0); + dq_accum = at::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + } + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts); + dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts); + } else { + dk_expanded = dk; + dv_expanded = dv; + } + + uint64_t drop_seed = 1, drop_offset = 0; + drop_seed = *philox_seed.data_ptr(); + drop_offset = *philox_offset.data_ptr(); + auto drop_seed_offset = std::make_pair(&drop_seed, &drop_offset); + + + if (seqlen_q > 0) { + ck_tile::stream_config stream_config{stream}; + dq.zero_(); // ck use atomic operation on dq + auto traits = + get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic); + + auto args = + get_ck_fmha_bwd_args( + mask, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size_8x, + q, + k, + v, + alibi_slopes_, + out, + softmax_lse, + dout_padded, + dq_accum, + softmax_d, + dq, + dk_expanded, + dv_expanded, + softmax_scale, + p_dropout, + drop_seed_offset); +#if (defined(__gfx90a__) || defined(__gfx942__)) + float t = fmha_bwd(traits, args, stream_config); + TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd"); +#endif + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3}); + at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3}); + } + if (head_size_og % 8 != 0) { + dq = dq.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + dk = dk.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + dv = dv.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + } + + return { dq, dk, dv, softmax_d }; +} +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip new file mode 100644 index 00000000000..f66dee9c95f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip @@ -0,0 +1,360 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include +#include +#include + + + +namespace pytorch_flash { + + +fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool has_lse, + bool enable_alibi) +{ + return fmha_fwd_traits{head_size, + head_size, + dtype, + false, // is_group_mode + true, // is_v_rowmajor + mask.type, + enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias, + has_lse, + has_dropout, + false}; // do_fp8_static_quant +} + +fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, + bool has_dropout_randval, + const mask_info &mask, + // sizes + const int b, + const int seqlen_q, + const int seqlen_k, + const int h, + const int h_k, + const int d, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + std::optional &alibi_slopes_, + at::Tensor out, + at::Tensor softmax_lse, + at::Tensor dropout_randval, + float softmax_scale, + float p_dropout, + std::pair drop_seed_offset) +{ + // q: (batch_size, seqlen_q, nheads, d) + // k: (batch_size, seqlen_k, nheads_k, d) + // v: (batch_size, seqlen_k, nheads_k, d) + // o: (batch_size, seqlen_q, nheads, d) + + // alibi_slopes:(batch_size, nheads) or (nhead) + // lse: (batch_size, nheads, seqlen_q) + // randval: (batch_size, nheads, seqlen_q, seqlen_k) + + ck_tile::index_t stride_q = q.stride(1); + ck_tile::index_t stride_k = k.stride(1); + ck_tile::index_t stride_v = v.stride(1); + ck_tile::index_t stride_o = out.stride(1); + ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(2) : 0; + + ck_tile::index_t nhead_stride_q = q.stride(2); + ck_tile::index_t nhead_stride_k = k.stride(2); + ck_tile::index_t nhead_stride_v = v.stride(2); + ck_tile::index_t nhead_stride_o = out.stride(2); + ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0; + ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0; + + ck_tile::index_t batch_stride_q = q.stride(0); + ck_tile::index_t batch_stride_k = k.stride(0); + ck_tile::index_t batch_stride_v = v.stride(0); + ck_tile::index_t batch_stride_o = out.stride(0); + + ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0; + ck_tile::index_t batch_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; + } + + return fmha_fwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + has_dropout_randval ? dropout_randval.data_ptr() : nullptr, + has_lse ? 
softmax_lse.data_ptr() : nullptr, + out.data_ptr(), + nullptr, // seqstart_q + nullptr, // seqstart_k + nullptr, + seqlen_q, + seqlen_k, + b, + seqlen_q, // max_seqlen_q + d, // hdim_q + d, // hdim_v + h, // nhead + h_k, // nhead_k + softmax_scale, // scale_s + 1, // scale_p + 1, // scale_o + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0, // batch_stride_bias, FA without bias + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + has_dropout_randval, + drop_seed_offset}; +} +std::tuple +mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &out_, // batch_size x seqlen_q x num_heads xhead_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_dropout_randval, + c10::optional gen_) +{ + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + + std::string q_dtype_str = q_dtype == at::kHalf ? "fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size <= 256, "CK only supports head dimension at most 256"); + TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } + + mask_info mask; + if (is_causal) { + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + window_size_right = 0; + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // casual + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
+ std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local + } + + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value(); + const int ngroups = num_heads / num_heads_k; + at::Tensor temp_q = q; + if (seqlenq_ngroups_swapped) { + temp_q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); + seqlen_q = ngroups; + num_heads = num_heads_k; + } + + CHECK_SHAPE(temp_q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + + + at::Tensor q_padded, k_padded, v_padded; + if (head_size % 8 != 0) { + q_padded = at::pad(temp_q, {0, 8 - head_size % 8}); + k_padded = at::pad(k, {0, 8 - head_size % 8}); + v_padded = at::pad(v, {0, 8 - head_size % 8}); + } + else { + q_padded = temp_q; + k_padded = k; + v_padded = v; + } + + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size); + if (seqlenq_ngroups_swapped) { + out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); + } + if (head_size % 8 != 0) { out = at::empty_like(q_padded); }; + } + else { + out = at::empty_like(q); + } + + auto round_multiple = [](int x, int m) { return (x + m -1) / m*m;}; + const int head_size_8x = round_multiple(head_size, 8); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + + auto opts = q.options(); + bool has_lse = true; + bool has_dropout = p_dropout > 0.0f; + + at::Tensor softmax_lse; + // TODO - check gradient, only training require lse + softmax_lse = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + + at::Tensor p; + if (return_dropout_randval) { + TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0"); + p = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(at::kByte)); + } + else { + p = at::empty({ 0 }, opts); + } + + int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); + auto rng_state = at::empty({2}, opts.dtype(at::kLong)); + auto rng_state_ptr = reinterpret_cast(rng_state.data_ptr()); + + + + at::Tensor seed_t, offset_t; + + if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + + auto philox_args = gen->philox_cuda_state(counter_offset); + + + + hipLaunchKernelGGL( + flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr); + seed_t = at::scalar_tensor(at::Scalar(static_cast(rng_state_ptr[0])), at::dtype(at::kLong)); + offset_t = at::scalar_tensor(at::Scalar(static_cast(rng_state_ptr[1])), at::dtype(at::kLong)); + } + else + { + seed_t = 
at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + } + + if (seqlen_k > 0) { + auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); + auto stream = at::cuda::getCurrentHIPStream().stream(); + ck_tile::stream_config stream_config{stream}; + + auto traits = + get_ck_fmha_fwd_traits( + mask, + q_dtype_str, + head_size_8x, + has_dropout, + has_lse, + alibi_slopes_.has_value()); + + auto args = + get_ck_fmha_fwd_args( + has_lse, + return_dropout_randval, + mask, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size_8x, + q, + k, + v, + alibi_slopes_, + out, + softmax_lse, + p, + softmax_scale, + p_dropout, + drop_seed_offset); +#if (defined(__gfx90a__) || defined(__gfx942__)) + float t = fmha_fwd(traits, args, stream_config); + TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd"); +#endif + } + else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + if (seqlenq_ngroups_swapped) { + out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); + q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); + } + return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p}; +} +} //namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip new file mode 100644 index 00000000000..708096392d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip @@ -0,0 +1,436 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include +#include +#include + + +namespace pytorch_flash { + + +fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool enable_alibi, + bool deterministic) +{ + return fmha_bwd_traits{head_size, + head_size, + dtype, + true, // is_group_mode + mask.type, + enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias, + false, // has_dbias + has_dropout, + false, // s_randval + deterministic}; +} + +fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, + // sizes + const int b, + const int max_seqlen_q, + const int max_seqlen_k, + const int h, + const int h_k, + const int hdim, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor seqlens_q, + const at::Tensor seqlens_k, + c10::optional &alibi_slopes_, + const at::Tensor out, + const at::Tensor softmax_lse, + const at::Tensor dout, + at::Tensor dq_acc, + at::Tensor d, + at::Tensor dq, + at::Tensor dk, + at::Tensor dv, + float softmax_scale, + float p_dropout, + std::pair drop_seed_offset) +{ + ck_tile::index_t total_q = q.size(0); + ck_tile::index_t total_k = k.size(0); + + // q: (total_q, nheads, hdim) + ck_tile::index_t batch_stride_q = 0; + ck_tile::index_t stride_q = q.stride(0); + ck_tile::index_t nhead_stride_q = q.stride(1); + + // k: (total_k, nheads_k, hdim) + ck_tile::index_t batch_stride_k = 0; + ck_tile::index_t stride_k = k.stride(0); + ck_tile::index_t nhead_stride_k = k.stride(1); + + // v: (total_k, nheads_k, hdim) + ck_tile::index_t batch_stride_v = 0; + ck_tile::index_t stride_v = v.stride(0); + ck_tile::index_t nhead_stride_v = v.stride(1); + + // o: (total_q, nheads, hdim) + ck_tile::index_t batch_stride_o = 0; + ck_tile::index_t stride_o = out.stride(0); + ck_tile::index_t nhead_stride_o = out.stride(1); + + // lse: (nheads, total_q) + ck_tile::index_t batch_stride_lse = 0; + ck_tile::index_t nhead_stride_lse = softmax_lse.stride(0); + + // do: (total_q, nheads, hdim) + ck_tile::index_t batch_stride_do = 0; + ck_tile::index_t stride_do = dout.stride(0); + ck_tile::index_t nhead_stride_do = dout.stride(1); + + // d: (batch_size, nheads, max_seqlen_q) + // CK assume d share the same stride with lse + + // dq: (total_q, nheads, hdim) + ck_tile::index_t batch_stride_dq = 0; + ck_tile::index_t stride_dq = dq.stride(0); + ck_tile::index_t nhead_stride_dq = dq.stride(1); + + + // dk_expanded: (total_k, nheads, hdim) + ck_tile::index_t batch_stride_dk = 0; + ck_tile::index_t stride_dk = dk.stride(0); + ck_tile::index_t nhead_stride_dk = dk.stride(1); + + // dv_expanded: (total_k, nheads, hdim) + ck_tile::index_t batch_stride_dv = 0; + ck_tile::index_t stride_dv = dv.stride(0); + ck_tile::index_t nhead_stride_dv = dv.stride(1); + + // dq_acc: (split, total_q, nheads, hdim) + ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0); + ck_tile::index_t batch_stride_dq_acc = 0; + ck_tile::index_t stride_dq_acc = dq_acc.stride(1); + ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2); + + float p_undrop = 1.0 - p_dropout; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + // alibi_slopes:(batch_size, nheads) or (nhead) + stride_alibi_slopes = alibi_slopes.dim() == 2 ? 
alibi_slopes.stride(0) : 0; + } + + return fmha_bwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + out.data_ptr(), + softmax_lse.data_ptr(), + dout.data_ptr(), + d.data_ptr(), + nullptr, // rand_val + dq.data_ptr(), + dk.data_ptr(), + dv.data_ptr(), + nullptr, // dbias + dq_acc.data_ptr(), // dq_acc + seqlens_q.data_ptr(), // seqstart_q + seqlens_k.data_ptr(), // seqstart_k + nullptr, // seqlen_k_ptr + total_q, + total_k, + b, + max_seqlen_q, // max_seqlen_q + max_seqlen_k, // max_seqlen_k + hdim, // hdim_q + hdim, // hdim_v + h, // nhead + h_k, // nhead_k + softmax_scale, + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_o, + 0, // stride_randval + stride_do, + stride_dq_acc, + stride_dq, + stride_dk, + stride_dv, + 0, // stride_dbias, FA without bias + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_o, + 0, // nhead_stride_randval + nhead_stride_do, + nhead_stride_lse, + nhead_stride_dq_acc, + nhead_stride_dq, + nhead_stride_dk, + nhead_stride_dv, + 0, // nhead_stride_dbias, FA without dbias + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0 , // batch_stride_bias, FA without bias + batch_stride_o, + 0, // batch_stride_randval + batch_stride_do, + batch_stride_lse, + batch_stride_dq_acc, + batch_stride_dq, + batch_stride_dk, + batch_stride_dv, + 0 , // batch_stride_dbias, FA without dbias + split_stride_dq_acc, + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + p_undrop, + drop_seed_offset}; +} + +std::tuple +mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_heads x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + c10::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + c10::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) +{ +#ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); +#endif + if (is_causal) { window_size_right = 0; } + + bool is_dropout = p_dropout > 0.0; + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query 
and dout must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32"); + + std::string q_dtype_str = q_dtype == at::kHalf ? "fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int total_q = sizes[0]; + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads = sizes[1]; + const int head_size_og = dout.size(2); + const int head_size_8x = sizes[2]; + const int total_k = k.size(0); + const int num_heads_k = k.size(1); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_8x % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8"); + + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + mask_info mask; + if (is_causal) { + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // causal + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
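+ // The string "b:<window_size_left>,<window_size_right>" is the encoding consumed by mask_info::decode(); the causal branch above is this same encoding with window_size_right fixed to 0.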
+ std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local + } + + // q, k, v, out had been padded in mha_fwd + // dq_, dk_, dv_ are also padded tensor + CHECK_SHAPE(q, total_q, num_heads, head_size_8x); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_8x); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_8x); + CHECK_SHAPE(out, total_q, num_heads, head_size_8x); + CHECK_SHAPE(dout, total_q, num_heads, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, total_q, num_heads, head_size_8x); + } else { + dq = at::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, total_k, num_heads_k, head_size_8x); + } else { + dk = at::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, total_k, num_heads_k, head_size_8x); + } else { + dv = at::empty_like(v); + } + + at::Tensor dout_padded; + if (head_size_og % 8 != 0) { + dout_padded = at::pad(dout, {0, 8 - head_size_og % 8}); + } else { + dout_padded = dout; + } + + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_d = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + + if (!deterministic) { + dq_accum = at::zeros({1, total_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + } else { + const ck_tile::index_t kN0 = head_size_8x <= 128 ? 
128 : 64; + const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(max_seqlen_k, kN0); + dq_accum = at::zeros({nsplits, total_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + } + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = at::empty({total_k, num_heads, head_size_8x}, opts); + dv_expanded = at::empty({total_k, num_heads, head_size_8x}, opts); + } else { + dk_expanded = dk; + dv_expanded = dv; + } + + if(zero_tensors) { + dq.zero_(); + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + uint64_t drop_seed = 1, drop_offset = 0; + + drop_seed = *philox_seed.data_ptr(); + drop_offset = *philox_offset.data_ptr(); + auto drop_seed_offset = std::make_pair(&drop_seed, &drop_offset); + + if (max_seqlen_q > 0) { + ck_tile::stream_config stream_config{stream}; + dq.zero_(); // ck use atomic operation on dq + auto traits = + get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic); + + auto args = + get_ck_fmha_varlen_bwd_args( + mask, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + num_heads_k, + head_size_8x, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + out, + softmax_lse, + dout_padded, + dq_accum, + softmax_d, + dq, + dk_expanded, + dv_expanded, + softmax_scale, + p_dropout, + drop_seed_offset); +#if (defined(__gfx90a__) || defined(__gfx942__)) + float t = fmha_bwd(traits, args, stream_config); + TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd"); +#endif + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2}); + at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2}); + } + if (head_size_og % 8 != 0) { + dq = dq.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + dk = dk.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + dv = dv.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + } + + return { dq, dk, dv, softmax_d }; +} +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip new file mode 100644 index 00000000000..1ffdcd5b9bf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip @@ -0,0 +1,364 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include +#include +#include + +namespace pytorch_flash { + +fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool has_lse, + bool enable_alibi) +{ + return fmha_fwd_traits{head_size, + head_size, + dtype, + true, // is_group_mode + true, // is_v_rowmajor + mask.type, + enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias, + has_lse, + has_dropout, + false}; // do_fp8_static_quant +} + +fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, + bool has_dropout_randval, + const mask_info &mask, + // sizes + const int b, + const int max_seqlen_q, + const int h, + const int h_k, + const int d, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor seqlens_q, + const at::Tensor seqlens_k, + c10::optional &alibi_slopes_, + at::Tensor out, + at::Tensor softmax_lse, + at::Tensor dropout_randval, + + float softmax_scale, + float p_dropout, + std::pair drop_seed_offset) +{ + // q: (total_q, nheads, d) + // k: (total_k, nheads_k, d) + // v: (total_k, nheads_k, d) + // o: (total_q, nheads, d) + + // alibi_slopes:(batch, nheads) or (nhead) + // lse: (batch, nheads, max_seqlen_q) + // randval: (nheads, total_q, max_seqlen_k) + + ck_tile::index_t total_q = q.size(0); + ck_tile::index_t total_k = k.size(0); + + ck_tile::index_t stride_q = q.stride(0); + ck_tile::index_t stride_k = k.stride(0); + ck_tile::index_t stride_v = v.stride(0); + ck_tile::index_t stride_o = out.stride(0); + ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0; + + ck_tile::index_t nhead_stride_q = q.stride(1); + ck_tile::index_t nhead_stride_k = k.stride(1); + ck_tile::index_t nhead_stride_v = v.stride(1); + ck_tile::index_t nhead_stride_o = out.stride(1); + ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0; + ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0; + + ck_tile::index_t batch_stride_q = 0; + ck_tile::index_t batch_stride_k = 0; + ck_tile::index_t batch_stride_v = 0; + ck_tile::index_t batch_stride_o = 0; + + ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0; + ck_tile::index_t batch_stride_randval = 0; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; + } + + return fmha_fwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + has_dropout_randval ? dropout_randval.data_ptr() : nullptr, + has_lse ? 
softmax_lse.data_ptr() : nullptr, + out.data_ptr(), + seqlens_q.data_ptr(), // seqstart_q + seqlens_k.data_ptr(), // seqstart_k + nullptr, // seqlen_kpads + total_q, + total_k, + b, + max_seqlen_q, + d, // hdim_q + d, // hdim_v + h, // nhead + h_k, // nhead_k + softmax_scale, // scale_s + 1, // scale_p + 1, // scale_o + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0, // batch_stride_bias, FA without bias + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + has_dropout_randval, + drop_seed_offset}; +} + +std::tuple +mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + c10::optional & /*seqused_k*/, + c10::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_dropout_randval, + c10::optional gen_) +{ + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32"); + + std::string q_dtype_str = q_dtype == at::kHalf ? 
"fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + + // TODO - Support paged_KV + // const bool paged_KV = block_table_.has_value(); + // TORCH_CHECK(!paged_KV, "CK does not support paged_KV yet"); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size_og = sizes[2]; + const int num_heads_k = k.size(1); + + const int max_num_blocks_per_seq = 0; + const int num_blocks = 0; + + if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case + + // TODO + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + + const int total_q = q.size(0); + const int total_k = k.size(0); + + TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + mask_info mask; + + if (is_causal) { + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + window_size_right = 0; + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // casual + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
+ std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local + } + + CHECK_SHAPE(q, total_q, num_heads, head_size_og); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = at::pad(q, {0, 8 - head_size_og % 8}); + k_padded = at::pad(k, {0, 8 - head_size_og % 8}); + v_padded = at::pad(v, {0, 8 - head_size_og % 8}); + } + else { + q_padded = q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, total_q, num_heads, head_size_og); + + if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); } + } + else { + out = at::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_8x = round_multiple(head_size_og, 8); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + + auto opts = q.options(); + bool has_lse = true; + bool has_dropout = p_dropout > 0.0f; + + at::Tensor softmax_lse; + // TODO - check gradient, only training require lse + softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + + at::Tensor p; + if (return_dropout_randval) { + TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0"); + p = at::empty({num_heads, total_q, max_seqlen_k}, opts.dtype(at::kByte)); + } + + if (zero_tensors) + { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_dropout_randval) {p.zero_();} + } + + int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); + auto rng_state = at::empty({2}, opts.dtype(at::kLong)); + auto rng_state_ptr = reinterpret_cast(rng_state.data_ptr()); + + if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + auto philox_args = gen->philox_cuda_state(counter_offset); + hipLaunchKernelGGL( + flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr); + } + + + if (max_seqlen_k > 0) { + auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); + auto stream = at::cuda::getCurrentHIPStream().stream(); + ck_tile::stream_config stream_config{stream}; + + auto traits = + get_ck_fmha_varlen_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value()); + + auto args = + get_ck_fmha_varlen_fwd_args( + has_lse, + return_dropout_randval, + mask, + batch_size, + max_seqlen_q, + num_heads, + num_heads_k, + head_size_8x, + q_padded, + k_padded, + v_padded, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + out, + softmax_lse, + p, + softmax_scale, + p_dropout, + drop_seed_offset); +#if (defined(__gfx90a__) || defined(__gfx942__)) + float t = fmha_fwd(traits, args, stream_config); + 
TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd"); +#endif + } + else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + at::Tensor out_padded = out; + if (head_size_og % 8 != 0) { + out = out.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + //return kludge -- TODO:: REMOVE + at::Tensor seed_t = at::empty({}, at::dtype(at::kLong)); + at::Tensor offset_t = at::empty({}, at::dtype(at::kLong)); + + return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p}; +} +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/rename_ck_autogen_files.output.txt b/aten/src/ATen/native/transformers/hip/flash_attn/ck/rename_ck_autogen_files.output.txt new file mode 100644 index 00000000000..78f844fd2a1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/rename_ck_autogen_files.output.txt @@ -0,0 +1,1810 @@ +fmha_bwd_api.hip -> fmha_ck_autogen_5919133d2ed892745013b2fc5d503414cf0a4d83.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2.hip -> fmha_ck_autogen_e11a3b7d4fdfed64e64f7a95dbc64eff541092d6.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2_deterministic.hip -> fmha_ck_autogen_01cb354dddef6e99e4ac843f2adafcddfc58d520.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2_pd.hip -> fmha_ck_autogen_1b3e7c8969027d3316875f33dc50fe022e05ce37.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_38273a2f8e6bbb42ba0b0871b6c95abb34531f33.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2_ps.hip -> fmha_ck_autogen_2d43460c011b8d5e01ea98c9b8ddce962de59a96.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_4c0c50a1fac82d47dff2357ee3ddbfa0b2c8d487.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2_psd.hip -> fmha_ck_autogen_2a3a980a26682d879c3a3425f3ba5be3f5761adf.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_008f2429c678d13386a06e8d8b15c4b480940ff3.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_group_o2_ps.hip -> fmha_ck_autogen_811db756577b61cde9fe8279d956980db9ee21a4.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_group_o2_ps_deterministic.hip -> fmha_ck_autogen_492fbc418e829f89bcb8d93f8afd2869dd8dfccc.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_group_o2_psd.hip -> fmha_ck_autogen_75f2010bf6c478d2f0eba77e912697661306c1cb.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_group_o2_psd_deterministic.hip -> fmha_ck_autogen_0153ec18d3ded0f8bdc6459ea5757ebd94d9faf2.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2.hip -> fmha_ck_autogen_3eb2ea922daabbba131b90713e06d8caf5f30662.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2_deterministic.hip -> fmha_ck_autogen_c0f76aff077c28f8afd7b22f284cf2894e08a043.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2_pd.hip -> fmha_ck_autogen_f48f8b681a405bfeba5aadaef40f32367ec5cd2b.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_4cabdafad0bf803223ba5e8f474cd59233dc48cb.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2_ps.hip -> fmha_ck_autogen_0801c56831b4c6428200db6318638a2129bb197a.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_91b9e2616c2fe0480096b1ccf0f74d584b220146.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2_psd.hip -> fmha_ck_autogen_4f1e1c969b57659e7e1367ac9ba10ed5ef5b69a9.hip 
+fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_ecd7dec90b3c62bf3a30bd75d3c6869529a06b01.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_group_o2_ps.hip -> fmha_ck_autogen_88ea5b5346c87cc4fc1e841c518080df4ab811a2.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_group_o2_ps_deterministic.hip -> fmha_ck_autogen_4395d3c96b3f4556b9765fd0a3b5701b2fb10948.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_group_o2_psd.hip -> fmha_ck_autogen_b8fbc6f6e9c515edce3c7a438b3bc308b30d3857.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_group_o2_psd_deterministic.hip -> fmha_ck_autogen_490a68220a7b621ae9817d7b77f55de239b0a4f3.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2.hip -> fmha_ck_autogen_344932e2655d7b32704be8de9a63bbd8c3369f02.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2_deterministic.hip -> fmha_ck_autogen_5a85ae0a16e4b293b549bcb6a3ee52df7fccca32.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2_pd.hip -> fmha_ck_autogen_963986150adcd6e1d3886bacf2166de1252e14df.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_8bd1a40b12ce927323594fcce61eb9c20cc5e3d4.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2_ps.hip -> fmha_ck_autogen_296c5836ba118969c4ba89ed62a98dffe3105738.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_6cfb7075345704340ff33dc0ef7c04ef127f26ad.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2_psd.hip -> fmha_ck_autogen_22511de2592b6e350737e44865e1fed6496e3f32.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_609f68180582384ba81aae2b1d4a4c52dde2c68c.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_group_o2_ps.hip -> fmha_ck_autogen_c9fe51f982abd60e567d4238d3266fb60e45814b.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_group_o2_ps_deterministic.hip -> fmha_ck_autogen_10a055e5c3d6a953d470db5dc21449766248058a.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_group_o2_psd.hip -> fmha_ck_autogen_327e27892bc57f3dec0da24f94f2a483d6c9321b.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_group_o2_psd_deterministic.hip -> fmha_ck_autogen_c581974c8b6f43f60d0af29c350d850b55c03121.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2.hip -> fmha_ck_autogen_01ac1a2ecf9a487809e46faa92e267df2d47de91.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2_deterministic.hip -> fmha_ck_autogen_dbc4135fce01e8731fec7a78d0cc0fdeeae28b90.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2_pd.hip -> fmha_ck_autogen_e09d9baa269dfbb30b714389d1733be51cc419b7.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_5f71e663978dbcba859c5114ec675a712e343fd6.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2_ps.hip -> fmha_ck_autogen_d257148f457557ea80ca56690e525db3a4b0ff55.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_8e2c587db8bd9f1b551624e0cf8b67a90245d7da.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2_psd.hip -> fmha_ck_autogen_8c13c4f3f645a2bb475eb1c55ce1de452f0e2332.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_7b7fa76609243a8709f349ffc0d9d88157f28dc9.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_group_o2_ps.hip -> fmha_ck_autogen_2b3326e055da32cc979892a2fbd0f7b003cb9f98.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_group_o2_ps_deterministic.hip -> fmha_ck_autogen_671828f15eec2a58be23063a1a8132d337cd26de.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_group_o2_psd.hip -> fmha_ck_autogen_457eaffbff3c58183a656687010daa2c16cfc26e.hip 
+fmha_bwd_convert_dq_d256_fp16_b64x64_group_o2_psd_deterministic.hip -> fmha_ck_autogen_d18727988e47264b42b4153dc82fc1a750f08db0.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2.hip -> fmha_ck_autogen_ab6cd5c9242f8278c8f3d9ce57b97d605c7e5a3e.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2_deterministic.hip -> fmha_ck_autogen_0c93c65e5942a2f43f2e491547add02777dd2eee.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2_pd.hip -> fmha_ck_autogen_d32c64ef01aa228277d031a74df51363f98aa2b0.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_e5c5079636a4a31a849ce8a5af89d50330a74628.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2_ps.hip -> fmha_ck_autogen_ea62567e9ea16771d8445464c38f5a2931cb355a.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_c6e2da8b791d31f4ba05ef5f833fd6dea9e35f1c.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2_psd.hip -> fmha_ck_autogen_f731289837f915e2aec1bd01eef1b3c1b099864d.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_233132e712eba8972ba444c604f89e01c5b84cc0.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_group_o2_ps.hip -> fmha_ck_autogen_afc4b47a6fa62a4ca5cff6a7e01c9f6b371d2215.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_group_o2_ps_deterministic.hip -> fmha_ck_autogen_bec30e7107c5dce3fe6aa87d83ed96da75478da0.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_group_o2_psd.hip -> fmha_ck_autogen_f4658c32d562f9d60c5ca1262a2e0df2375063bb.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_group_o2_psd_deterministic.hip -> fmha_ck_autogen_9545f95c1093c60f0fb6c794636f79aaeb53b733.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2.hip -> fmha_ck_autogen_e6b53fb8d81148ff384d31a703bb4c2e7a5a33af.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2_deterministic.hip -> fmha_ck_autogen_7aa14aa94d625b33df1adfa30ef4d91769592608.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2_pd.hip -> fmha_ck_autogen_b5db3d5b1d8af89381fc4b8073f84c5fa25fdef5.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_e8a9427f34bbf5ddb28a39161acc36806e68f2d0.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2_ps.hip -> fmha_ck_autogen_724d1d4408196d611b2e0535bf8833652acbd6ef.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_a3ac4f93722dc314086f1b7d7b8adc687cd75f82.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2_psd.hip -> fmha_ck_autogen_377b70f54cb2778b5ce3df936b477f775eea8b3c.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_5f20263fd84776f155519b3481be5e2c5b035585.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_group_o2_ps.hip -> fmha_ck_autogen_9745b04a8026a01828c5dd606d89d044d3ed1d99.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_group_o2_ps_deterministic.hip -> fmha_ck_autogen_a7784b03ad757d51c234fa86ea9891f055ecd5c1.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_group_o2_psd.hip -> fmha_ck_autogen_22105635385fbfb5d2f330df83ba6747bcb27f6d.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_group_o2_psd_deterministic.hip -> fmha_ck_autogen_3afbb5ac9048a962a60f48886728220ae6c2aeaf.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2.hip -> fmha_ck_autogen_429b82a27571ac91e3631cbdb7e0a58155abf962.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2_deterministic.hip -> fmha_ck_autogen_dc818f3ce244743cb1dbff9aca399df90742a6d0.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2_pd.hip -> fmha_ck_autogen_7f9403cb91d6aabebf081afae94a8ba397d8d24f.hip 
+fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_ca5681d4e5871aacef74bdba9e368445875252d3.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2_ps.hip -> fmha_ck_autogen_1e7d7888480b83c78833214b32e10f37a6e20301.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_4018f690b6322588041bb467beabd8a7bc79a2e0.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2_psd.hip -> fmha_ck_autogen_23047ea90076e3b0a3eb0586d49b9ee74ca6d279.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_5a216f777feec4752f5882677b18168225da4b53.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_group_o2_ps.hip -> fmha_ck_autogen_fd19d7614f2ed5da21a52ed172ef62cc07c9c01a.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_group_o2_ps_deterministic.hip -> fmha_ck_autogen_9893336a4b00b2a63f23ed7e13ec54c82d9e5063.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_group_o2_psd.hip -> fmha_ck_autogen_131c1fdc4206bb952b2fea675f24e3b09f605eef.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_group_o2_psd_deterministic.hip -> fmha_ck_autogen_cc4ac5a18f57f2ebb65f7e356e858ab0d59b2133.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2.hip -> fmha_ck_autogen_dde93ffe7fca311e136e42fbcd12b05c9fc7174c.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2_deterministic.hip -> fmha_ck_autogen_7b67045d438a7e4b8f3a313a5df5a85f351c1be5.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2_pd.hip -> fmha_ck_autogen_9689ecd7bf51bcffe9f5002959bdda41c50a3c8b.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_c41b6eda4f250da059fe0c428428219ff5a250ef.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2_ps.hip -> fmha_ck_autogen_c45a5e40f6a66bc5292a56e0097c69fe37cedfb3.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_ffc6056d9fe125a4dbe08c1d86354e51f7daadd5.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2_psd.hip -> fmha_ck_autogen_2995d39cd62f20622a31f11a292ed175abb5fdf9.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_cb10303a0b79f2710eb7c66896d3c1f8b12c04dd.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_group_o2_ps.hip -> fmha_ck_autogen_81dd3ea61bb61de02667b14f5a94198f48c7307b.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_group_o2_ps_deterministic.hip -> fmha_ck_autogen_d3af8763f289dace1054bdcb4dfeda28b0aefcae.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_group_o2_psd.hip -> fmha_ck_autogen_e6e6b10e73733716e71ebf5a53703fb935fc5e02.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_group_o2_psd_deterministic.hip -> fmha_ck_autogen_e75c757c67aa23cb88e1aced6fcf36b7b28391db.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_2b3af90387f1d227119c5dcd4b71362940bbce52.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_e3015c5d50481547aa5754d042d9d7040cf1c7ff.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_a4700d87a19a173e84d64e43cffabbed52366e35.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_6af4c15a119e805e4407b184625f57966f8833d9.hip 
+fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8b17c082f249649eca733a8f0cdf9a1205c3e3d7.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> fmha_ck_autogen_226662cf1c9900a4334d2cadcc5f5ac3ad355f05.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_d723b191785c97d284675f700a7baeb52a2eb791.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_afdab954fd111ec48721f25710d61c0c8affd8db.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d54ac01458df3f240e0656d82330f9de23ba9651.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_4ed6da5357b67cc28aee4afa9523adaf055c4e32.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_8c3bd4e029bba76ebfc79e6522dbc8ca0bba5dd2.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_dbde2ef18e2174ebe13a6e7c8c2a6b05a6612047.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_c363ee1b087f6b504a3dd3972b96e77db02b0582.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> fmha_ck_autogen_a02a71fdd587e47ee68e0cc76c3c4494ce06c359.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_968fc75a7d102aca068e3ceb6111728c280fa837.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4a06b5b153ea6e8b1e20d9aad9d4633333fd98f5.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_bde24a8dbe6add6f2dd2beb48b1280f3a84a9b2a.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_415b183c50dd2663dabe3eb8b780913b778c54ab.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_258d747083272ea657604ac84867ecea17bd65da.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> fmha_ck_autogen_2a97c457144cb63a9c6c3d6be613b47bd0df9928.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8d7549e66ef309e32779ddc2a1f14e79bae53754.hip 
+fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_5cd41b6f578f3c903eb9d58ebfab62eb296044e0.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> fmha_ck_autogen_dc34b6ef496d4e0d8fbbe10731d4a7b1c136c036.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_4a9f3da698a6103caf25d785928dd9f814ac27b4.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_aa996b9c843200a2ec33ed4319b48106cd7c6384.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_1d02609fb803ea2697e2c2cef35e6f923d2578cf.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_58eb2edc7738d8d18ac359691da261ceaaf71788.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c35ea54eb6cd0f3756c462c66d9be956279b46ad.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_0f0c699d9c3b0ed62097e38ba05e40e815cf474e.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_64fe2db75cb20428856b02cd1cc8d7b393a6ad9c.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_38b94d76503e13c911781169fbc378517332c42e.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e0966fa1ff013e477b1706928de6cb7f8587c154.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_b9559dd36a0a4f5e068a722e285f485137bd5ef0.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_5a05b4e7782bd0e29ca9f6d33fc59d4304136d41.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_b9385db12001110c42eff6aabad935a69ad3afe2.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_c1f721a330b2d0fac13b22061616d7b10c0f91e9.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_37fe04467e87ec2110f60c7aea0cc9bf2ca07481.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_d4b99af9a573df50a27fccbec3fa8e350f1854eb.hip 
+fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_20588bcac681a5d69f252d7523a3681a0c6b6181.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_a3709e4fc53d2254a03ea7660b8c72d2f47cf1ad.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_47fe73f04cef91cd2a0682e905483968ff80eadb.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_ad9b99a194b59d3149842c15733394da275b12c0.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_062c8c3c1cf6c33af4574099e9b6ac54a55ad776.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ab1ca4ce061f7f69a250356f613cab00d1e2ac71.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_cd4efcdd12184211c74e7b3f2f30fecf1041ca32.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_1d0b822743e0205f60521d38d7c64f589fdf0f58.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_e10f47a44400de385ddbeb99475b717c5646fb41.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3e562e6c3af28b8478020ce3c3bf73c036001c93.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_1a99b2625adffa8215276bb88fc65bae944b846b.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_56cc4399c5567a9495f17d54c712cc9e65e57521.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_ba8b09f0aaa40a7c9ad5f0458b460d3e328f3c74.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_43e7c78e8f65be35e2753a0ad5123118555c56b2.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bed5a8c5cf683f6dfaefad72c2e2f5c2f2b2732f.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_73ec21ed6e040260c4f04ef68ef9307aa86985a7.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> 
fmha_ck_autogen_3642b78913a853a62dbff8b99d9ae3fa458f461d.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_ac5e9aee85cd16903bf7b82a4ac10402b0b26e22.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_5f954a393b7b5a7131c13d0c4578443f468a738d.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_78e945db4afa1330fe3978bc1bc9ae99828ae287.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_d4aff499ad527be5fe33b8e92547df57af26d40d.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2b8169ce4b4b9a17ac96fbb232e6a93f22071ab4.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_0a89417a043556970f72eebd48b4f3e7ac15377a.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_4824e1f8cda50f80988857611da766685da94494.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_dbae1670fac6812b2d2cbad973e4b475509ea504.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_5daedab8931f2eefb649b91e80145cb71b63360c.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_dfcd68acfca68d1acac94f493e25be0ef20f209f.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_3511c54e6a6f9eec378d8b661121066536195d3a.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_deb9ec2cccab94920e40f62a1f0f094acd919d07.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_0fbb0bef3b388867e75d7a8a187b8b4b650a42ae.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ab0be5a2072b5e87f5ee58149688796b6513219f.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_91a6200e36944b1f11106c02f7fcee053f01ee71.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_1f81f8cce0d77dec9f977b9eeb0778b70a13fa75.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> 
fmha_ck_autogen_bcd7ccdceb7baf3b986f2a0248827822a5f72e47.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_58762476c7f2bb05dce92ec22c0acbeb03676746.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_f4df1cbfbaf67705820f125b474469ad7ebab0c0.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_f42cf0e5fe479690883507028748b0cd3dc83cbb.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f682399cd6412fed6a1141296a7e4d42078f7b29.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip -> fmha_ck_autogen_256ef175029a43e64164176d4eb212baf9d27bb9.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_3206cc121ce8955ed59ea3b12b858ee2e0cf82f8.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_d1840494c4fa78ff399c0399b3ad7ca3d22d4587.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_31c4b866692ba5c3d115482bef4790733863c1fc.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_b5c7fca1f76a31b0390e92d90d569fab94d4f783.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_dc3d625c5ad3e871f5a727ac946df642d988b9ab.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> fmha_ck_autogen_ca4c6ad28aff1976c6dd36974ec3b339aa3090e9.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_eac353f963c52624cf79e82cc2b2c02eed94b677.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_29bffc159b0bb826ba489ae763dae141bfe8e802.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_9b327f0fa1155f2235d76be45cd22e3db5a69429.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_d0dd0165ee91c095a19ceddf08789e3576912590.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_5344427df3ae9392c4fc4c25c232196828e70648.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> 
fmha_ck_autogen_3f7315955f555768f24585a50d75e216c40f062d.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_dbcea8f7b5930abf76eecefce92d0db785d2df5d.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_165dfb45658df8f1ae8dc0738ac9614740f2576c.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8a58d4bca33c4c0e79141a56688049237d170d1b.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_fe9d98dbec5096a89b116f85675af772f023014a.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_d1c0dfd19a08d61586758091370acbdc6f267017.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_960ecb3013071fb65f2d5ed4c947c4bf303e5308.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_1552dc38d26f6badb7a9bcb5ce9124d54cc45ed3.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_3af86f458fb4dfcceb7db3357fbae0dc15142a15.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_74ba59d347ce8916a22b40e6f22a3c89e13db4d0.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_7344f96bed2f56793b1c2583485aa161cdf30379.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_ad989d2ce769f20e175fa88f4082c1c25fe03062.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_096e888c52d0f4a5847d7515fcc66208b1ff40d3.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_7cbe4562c51d6829ec5942e11035c452fe318b3a.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_621da34ee666903307d3a09b7a032f2a70054759.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_c64f4cdce32189065362a502105c31bd2d9d99a4.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_987f00dd759d9714693e7517dfaa8bb427294d42.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> 
fmha_ck_autogen_1c2a2d78176e3f0a78e3ad78217e75a4430c0de5.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_ba145535e53899fe127987aa854f81234a9c51c4.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0968cebd81ade762c2f92fffc0153fa7a2b91eb5.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_b41735d250b5a16967281a5f07873b9cde3df4d6.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_fac5a0f98b94530befd634891e42c424bb86f0e1.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_ffb8adef0cef91a86f36872407fea35df90e8f2b.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_78e1edca5abe1bb3e7aa946eab6484b7bed806a3.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_88ed7f650c958a644c8031aeb88688b1e42458e5.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_ef2ebb4a86e7ed0001de9c5e607b66fe8877409f.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_f3ff73f82aee3184849d04c2364eaa45c6d0de9c.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_fb9477a613665cebcad781389ba7c5a36f51efe2.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_21f860d42fdc2cc6bd743d53ba546e332c22fedf.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_55ea83a47c6299fefa4220ed88f7a8e1dd938215.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_80987e2d765efc320eaee813607c94c80ee35aa4.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_288458c5a0720ef152848713119ebce6d76db6d6.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d6149eea92f2c40c11de3b778102fcf9b6a006b8.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_7b5680f97836be4a369802e8115617a83875703e.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> 
fmha_ck_autogen_4347e039c003489dd528faf5d710e687321a3fd7.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e5b2bb9f8466de1ad5210e4c39ee7b8ecacdffa9.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_bc6ce17223d8d83a64b8c96ac88223e4441a4692.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_fc1790325b59bd44b0a5f6cf9723a25fd845cba7.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_662767e588220d0dc6137b00cc1d8dcc91e97134.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a3dc780b17152f696f9b957432c2eae8fb16e85e.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_9a8e04fe9432a60f86ff0369e8c1851821074a04.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_835a906031a258c6362313eec783678bd8125c91.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_ee8e709eec7aef1fa681053c6d2969a5ff18c45c.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_8d079c1eb36db8461fa8b861c56760afcd97cc34.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_64b3488ddf3bb1a4870371882f0a5d267bdfdf73.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_ca3975efd767ddf7c12e308d948bdcaf0968493a.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_82ad0c0580516485ea432d98f53e73f6dfec548c.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_4306c6c37cf472ad262f53941611b5e60072bdf6.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4904c5910a2d0595b39a3f87652a9d1ef4fcbe80.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_f57f84892e2a8496169b7406e63b0d4f5aa63aaf.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_f24f26e45d5cf567d29fbe375fbf8abdec39186f.hip 
+fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a85d35b2fd98742427930eb536e346ffb005edd8.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_19df4e13108e043361e9528b71df56f04f696a0c.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_dbb06b43d5d65429e23cc717448cf1fffb0cfd74.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_0ef9b9413697d6f4573c6605bff6f58d027c5016.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0b2efefea81036641561bed80c75d77651176f74.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_27c2000d32c230a57a6712f27bc0fba02722f5fd.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_ab1d7f93427095e39bfc1d986b3d7fe54073ec75.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_7dfe21ee27f8a0ca0407ef0dea73cd73ae6940db.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_8007bf7ae1b71bf8ac4a793aa519ad333aa7a7ba.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3937d9dfb68351de2942e32f35e2ca1ce71edfa8.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_de1ff66d2aeb47d2fdccaa4bb6b9d066b380c99e.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_5403eec1cdd216d5c4a7ba977e2ef92a0d7fcc8b.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_358399e756ed5026baf3ab78af17489dc07b9532.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bd064e302ff5b983dbdb4ccf51383fb29ddff44f.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_c11d68fe766fc753c657362673704005b538660b.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_fbea85b766bf0c918ee0baf24dffc6a5563d5105.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_92f9ad0fb65638cfffb3e7786f2cbf01d9585b23.hip 
+fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_0a55ed15ef58c941e06dda890aeb530e28eb7bba.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_df4bb75ca79f805a81fbad750ad22f6d22b0d8ff.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_7ab03a62e064864e1e9c1cd506c1b2e1786a777c.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a189292c81a18d21a2921ce6740f81ebf4c046ad.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_c9312d7159369d13f3148a6f0882dfad6921ceec.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_0cdef49859c80c6b3ba18eb2fb4c35c72abc1cf2.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_ae87b1d5c50606430b544ed650d87df24366e7d5.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_0a92671b6ea99891c0d69b1c793f4d131b9a82ed.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f4a6438394dd3427f29aa0bbe58ad1f797c3c38d.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> fmha_ck_autogen_fa85f869a92f0482605e52019828244b12e12b44.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_c2541b6b5cf27de3f45f60671d36602f07ce1783.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_0595316f0dfffda03e5296b959a49ec3f3c48d67.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fff7aa57cca501f221077124359a589b3a6f9d0a.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_358d28c958c0a831a615a4811d13279b18db09c4.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_96f1bb85dff8c97846f6b2e8796a6289bcd0d9d3.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_14d4630876785655bd4950566e81ae0b645c0d3c.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_a48843d844f78690c7a45b730652f0f763c595c7.hip 
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> fmha_ck_autogen_3e143d88eaa0d9cfea856b2f3a57d1275a656627.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_18ed7195a9443c84956c3f32839cb3ab9056bdfc.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f7035f4bfd8f2f427720a07e3c311bccc1dba683.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_f87790f260630f312b84888dcbdf849ce130ae59.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_fe97b7adcd67ed9bda8831d1f3f1ca7590c6d251.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_b41a30092e8138877c1f6c25656e0f8ae2c2444e.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> fmha_ck_autogen_af06c0dae15684f83e15722a4c07342af9ea011c.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_158d5ce564c3ae1eefb54e3d41dde2604560ef4a.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_49f5017cc0f5c8c8dc71492e7765cf729c1f225c.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> fmha_ck_autogen_280bfced8745fbd9266207463fb41476dc23afff.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_eca613eaa8471ad7da66d2f8f2b8e07f6e02b467.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8e1b48a28b71c7f4c78eb14321b39951a7c5e903.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_a1d6ad9de7ac7993ae1923a2ef070b7dacb8c563.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_04641230fe9a50a221047f7a1df8a370f72805b9.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bc1ae1dddb8cc5d78196da6b26ebe66c1ce7e567.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_e8d9b65558398c0c10127b560807578ef117d7ed.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_87e3a06266deda093bdf28af82d8666066157fc6.hip 
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_0a672fca51de618e3441cf8764e8e83eb782f2c7.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_92d841e6d783bb46d841aafd9027f92dd1b61b88.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_01f74764c3c3284fdd1b67d0ea781c2261ed0de6.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_feb5e77111fe1e20bafdb83a925b5faeeb6214af.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_26d77b228420a3ead919474ec9c6fb2800f86890.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_4fd34faa8b168e2ac7862641229e6146d3e28aee.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_5724d91c1fd6290a6cf8d52a3801ac6b921dc7d4.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_dd11806cd2d3ef1127f676b2d98bf8fff2a1e5ab.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_aceb0641213e9a45ba48bcf72bb23845720d8b79.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_c0338fbc05f86270ded7df2bd3e2758a03961b62.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2e8b4260626beeac76c26dbcee3cba1457b30e99.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_4e0a88ccef04e81b8c684b695f7cb4310e448915.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_6f31b3345893eec8ed1ddf1d8de2512b46ff6187.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_83d920a76114c63156740ba5dd6f3846c4b21c28.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_86fa51b8c7a2f3fac5cf4cd2951ed2ede5c35450.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_e7b2eb64b66d46359fab44333c2c484f4c9dd5de.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_81acf1d17650712b71a499bb66909bfcfcb6aecb.hip 
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f1ecc90ad7b86791a9e6f73a582aeff30f393804.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_f01468c62c878295443981662e037ec5213cf7a3.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_e2deafd2f36cee29109fb824e0135407453adcfe.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_b1766695dbb790bd614b83dc7569ad449404cc89.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_784c35fee4d372123631312f1051c43e1fa12378.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_38bb367362fe2c4849ded728ec5dd00969ce188f.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_9afe4b6f3b901ff4af81bd4f1cd8ff19f09d0b07.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_9ca3b1d36d777213eb381b47871bf15dd163c994.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_d5edfe3e3dc3008b928c8e6dbd50784b905f189e.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_10c24f1f9009e46afa3a59193784cc2575f79056.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_7dac5d4cf103d658e129673549549f1276f134e0.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_c8dbfaffc8a9b573f194f9c63f1175d9725f8950.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_566e26d4969bc6bbe9b092bedab11cddb3360c0f.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_0ef309b923172f4c0fb38d9b9f5325b33b4877c2.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_3bb3b682eab96e4e173affad75b9d8e73f1dd690.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_f92e9a82c879051d6fe3c42108f8a574187704af.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> 
fmha_ck_autogen_4f44435491aa68acb3217b0e693232c67641a2db.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_6082d55544b5280b49b071ea277fb1827193fa2a.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_81bb8f13b6f20a72c9ce6d0b53f81eddbf05f1c6.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_1e42736d4f677a59a172bd6f162616a437696351.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_b9ed0a64deb55616646ea98b21a891c971cd98ad.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fb2fbb135d59028afcf867c2cf08edc323565528.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_6360621af3f7e1e81a8be48fea8d2750fdecbbf4.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_097b3e1dae9bfb2e89398706508f8e01966fd4ea.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_4409f2a7deb027e864afdfc9975d3ab93c5dcc9a.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6d307974bdeeef95cca0d130ebb7aeb77fb1b6eb.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_01ee0083f6df962c4a754cd3295b1a436c590a0e.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_c0a3c4ac0a50bb9b7ad764929dbee98c856b1210.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c27b3026f1dc3056dee3a3e64bf31c45683607c9.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip -> fmha_ck_autogen_5af96b404feac271dac8f4190180754480d3ba80.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_f69878f4ca8cfe6b8d8748766f66a1ef8eab20ad.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_8689126a7eb09d81baaf8f99dbff8932fbeab3cb.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f6856ca950bcf173571766c3f04de4163be0402e.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_d036096f49a89730f8af7e75457c88cb8ae64165.hip 
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_03ff035717140f7385282419598cb4fb2881ce8e.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> fmha_ck_autogen_de85901d66dc04b1143bb6404445baf65693b781.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_5c742b9ac6749f189d597ac97d46d35189472c50.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bd9c47f3305e47db6ab6bc627fb3d80269633074.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_d82773721479613ad72e334510a248f1436b38d6.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_cfda56a4eb08b803332f25bda6209932d9624acc.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_328a311bafd1c153525393b252e4170f8aafb370.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e5935fbda313d3518f142f43d46f56c600f69286.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_48e9e858abf6f77489f3fadc4ee81edacd26705a.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_f71f96ce4dcc7f789a8ace73c230c203b05ff6dc.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_01d12033d59ce2799a2a024e5d9232325ccf1320.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_5854f09511778dd1779a839b0b194896070f69ad.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_7237ce5f3cf13ace3efc0b0227ae5a8c1fdfce1d.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_3b4ecb47f9ebe8c2784976c3e9bbe4834b475cf1.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f18c74becc24a93427d9c0838784e9b6caad6e81.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_c4c6c405cefe204824e8fad1b3dd34bba87e796a.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_41db3f29d1940e59dadc357c040ea37a6ff208d9.hip 
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_df4c9eb48da49a61957537270d94e56cb4e426be.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_6018ab272d7306689c7dc5a6d5326efea1471235.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a421c2ed6b295c458071f1988b9d6f7b46e8992c.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_61a44ac409e914c12281f1d26e5b52d8bfd0df75.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_7e332a6aeecfb12dcf70c69157fd3137343fb9f6.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_2e43e401abbfb1b6737e4dc822f68421abbc648a.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4afd02981f92fbef6277c1985cc479c12bae9239.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> fmha_ck_autogen_8513d96a66a4d9fb8dfc84afba7e1d8c200248a6.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_c4dec99707511cebd9188d216ee0a148d729b470.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b75843bb13058ffe29251e053800c509c7590544.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_6eca9cd905ea8b0454cf9564643894682b08cb97.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_c4b34d3cb673447773f6da23e9cf52b98e99f718.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_fbeec221cd63adaedceec39db41ea942f99f5133.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2d7b637e0313cb423b22cd8844cc2997b3ff73e4.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_8fb224b40a7be7db0a9c5c08cc5ab05b526c14e8.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_e28fd64c2f2b27577109a984e6ab82f5f0fcb296.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_2eba937ff6d0302ab013db7349d4feb914107f1f.hip 
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_4e79dce18e49ffe024fe4cd0693ad3399f5edaee.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8cdcdeb845e7bcdb89ef70ab2a97157d4db3cb52.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_30024440e780fdf9ec94deccc85216d8bbb5788a.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_c1f40c3421b9ad8cf43940530ec50bcf620058f2.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_7e89f79217037e361bb0909d06534e40f5026b4f.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_44564dddf8b492d80be54854abb8d1d831e42679.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_7831ce329f2a0812ebb1dd103ea4ba8cb7ba531d.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_b7a03ab0b7887cc7ed0cb40e56360a8d36c0bb8e.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e986d5f8d5591f3e0f1cdfad19c38c420fd93023.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_076b3beb57b30afb30636f948e3989b346b38d20.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_2177d95cdf45f6fec95d1812f2ef183a75259e38.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_ff6862dbdbb20bc63a650e1f93e9ac169bb702b2.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_649336d59a8b35919e593217b6fd4314a04ea359.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_14d11aad7b666f500f68b264a2fcca6dfc5f1a05.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_4d5f3cf0f78f73df79665c26b20b0805615e1b04.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_4bc48576f285325345fa1205e5e7e01787b74f71.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_e7c0a99e949baa5f3a7ee2d6e84427982f82f76d.hip 
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4a2e6b05e7e4de2cb23d815f8b2c8adf22131c0c.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_0842c4e3aabdf55405b3ce09ce1899245ddf11ad.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_9ad1f99284aafc8d7908d062f179a056eb314925.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_b80d0828ba6d24ea3c1a97bd9835ee937b4b32fb.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_847feaf237911478173377a501ee19ee325b012b.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_1a8da3e6ab050262b659c801ccf9a14787d7f176.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_0225857454eaab2eb664aef7a0849ce12c32fdf9.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_80a72d70d80b66c19e85daa00497308381050048.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_8b9043572cabb65435627a3faf23b18d039bbcd8.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_378759ae25465c32960487375828e23c5f1ac869.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_83ddca2c6ecbba4314c434e7471ffb8fa642f936.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_40db688a9189e1c47c300d474df946a248a63303.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_ad091c69d19b27f7ad50ef6311532ad8b642a9c6.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_5e735b12d130ebf849ac5d6752e413ecf3e69fbf.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_2c77bd7e89ed832cc31b2995566a49bec6e4cb52.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_133c51948cf8584900807998da14d788039f53b9.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> 
fmha_ck_autogen_c29110dd501853e87ebc122dd1971b0bb1bcd92f.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_9a9edbe35a8fac7796f00bde836bd547044770ea.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_ccac6c0e61b65c9422c7f30fbd979031698370a9.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_0d13a4c8d169877da6408584dc1f20a6f7c5e3aa.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_744ec604c577a27e0aae5b39711a9e2eb82801b6.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_43f2156a04b18bab55af60e9357f28d8a4604e8e.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_dc9e54273c0ea2358fb573a7d918aa7b09fe07f9.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4f0aded9d1baec3125ce8e176248cb146ca580fa.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_c80dce1a17d073259250ec0c87ade69e639ffa8e.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_e307a1b0d5a8f94e0a0f4032f401d20b4b643523.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_05538339c21c92c53d237865d72debaaf2ee5075.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ec3deb1382003ac010d9bc1c59d1878d3ec7a727.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_1f7faa0b33a9aada86f032174afd40d18efa7715.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_4462b192a64efb60d5484798526278ac7a0fb9fa.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_3a2643099365d0903c799585f41dc1a525ac9f9e.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_555ba79201a585bc091ccfc326fd24e851d1eecc.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_153e897098539c3466da9d7a37234daf16476277.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> 
fmha_ck_autogen_38a5ff72f22e0ad040a281e66b1aca0bf3a2aadb.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_4b2e7f96b095ebfb66ecc7a75752fba2a63e4f37.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_0fd4068ea93fcf4df463e3bf3a6898d23b65da7f.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2b823c3b99e7c8d1cdc39a5dbc7365a383bf9ccb.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_3824e97d5ecba46e06d5ec1a9456c810d80227a3.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_a5d4eb673bafd81e3a0ee213da4603d88b8460ec.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_40aa64439b80ff8dd12498b3e5f6b625da16e285.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_f3bf7ef503bb026258b3ec3d82d3ef1443046964.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> fmha_ck_autogen_556cd05288e1666f5c67fb87ad02ce660e4c589c.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_fc030b61ae20c4b7d9b2d10930a17e01e9e93328.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f069b38b26c30bc770f74c856e47eb498f5818e7.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_fc7b0916744b593435d8e1e7b6d874d760cd5e3b.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_abf92a5314fd33491b5eb6ebd2418b7e0d5db774.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_d41b6a64dd181f2efa65aaed03a3d229b3566c1d.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> fmha_ck_autogen_80bfb0e6032892cc58cef4dd403f305a5b76851b.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8301bfc0394936a68fa0098580f06e77c88ebed9.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_e9b53fa68641f45baabf40b7cfb8b35a9a1b9c7f.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> 
fmha_ck_autogen_c9fb8343e623e46f01893a2b61345d1ca5928671.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_320a6196b662a1d3dc7441a9536d825dc356b95d.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a3d7aa46528ee74e2bef1e87c1feceacfa55e173.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_c59937be2b9a13d6520fdcc922e4e75c9fa085ab.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_e477abef05ff37ec27705eda51896e2aa3a04966.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3da8c31f6d5bcaacfa4a21aed4d1d3caecb48922.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_6d40d762ed576832b3a752453e9881b5fe6d2650.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_3c1454ffc1418dac641f63671e947d9f550b1f0c.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_96c129dd4c798343d6f78ab78056f0faf2f1c9d3.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_242013527a0266ad479715ee3e6ae01c45de29d0.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_2dfac5a83def98340c8786d55a30a98ad68b9eed.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_ae51b30c7e1cd30e550187458350c8db7c59a9ef.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_5e0abf4e2b6be3e2c555c2134705b9dcaee617ce.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_7309c38fc8a2d5ad6efd449107dc54a7509624fe.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_245d90000b55ab8b6055b1934880fc6c4870b34b.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_0b9585ba1c10acf67115c5899b3546608541820d.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_8e431313fe082958d31b68d2fd0d61df0fe56736.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> 
fmha_ck_autogen_1db03461737f1e359f389a8d297476f9b60faabd.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b4b037a2e262d11d3ed7d9feeb41b9e05427a739.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_c919b8ed877d4244d01a17ecb948b459e361ff24.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_64cf03c0aa3f1b2a7b76b4e3418eb5063b982a29.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_1386cd75411e61a8dbbaf2b916e62f4f5f99104f.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_6e8cda718e10824956f0ee39bbb0891eafa45a7b.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_2ea394a09c8691a534ad2219bedf73724b6dd5ce.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_748a3d76e8ab73af9a5d2302d33e3b1d1b866dd1.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e907e8d1089557dfcc95a05160be5092e9119a53.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_c4c3425fe683d35dc3335db77d183ad1620b7a92.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_d04dc4ed02eb42c3fe303342801ed3073a0dcb8e.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_3ccf0a9d5a5451da5dbf6075ccea45e4a140550a.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_d924ee32b178b6bffa7a71603d6e2818f66177a5.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_aebd5fed34ebceb879ae3dffaf58c7c04ab5fe80.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_5939e6610e41aff8d1ccdb66d9e84d3e48e8d379.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_b4bd2d206ceb237ed2c51f58abb5cbf96e39d07b.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_e56757fb17f5e94a6ba1fb14540a68c36d571159.hip 
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3076a6de0e2612279e0ed64612f7393856bcc9ac.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_ea6a6d4cc262ea838dbb83ee747112f95fa297bc.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_1a6bc2762b95d550485aa720edaf71138d94cd07.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_614a9f10ebc51bde3f580ef527c17f89489c12c7.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_0271bd8b7c270e1593871b638288a4923342c446.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_4b74439f42140cdda9bb0f78d995d741212a35f4.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_d733f4c03e338ea7c6d8f759c1132499bdcea059.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4432c5214c4d40c54ca2d02f0d4785c6d6902370.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_1f13a6d0f8c798c0c4ba4ad202d081899fe081ab.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_a1c71e7d33f0597fe090a3524e33e18b2e562680.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_e13b86fe4e153e0bfa8d1e75f3641fe32b0c5149.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_adae2d4f8b2dac799e03ea6f279e6ecdf66f5381.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_70586668a61ab88bc46b763df8f1c2ea52001ea0.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_82f0f3d71108dcc49234a258f0f3b21ea2123cc0.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_1de2f97d49f015b9af0b186801e939c6f357a0c4.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_9dc424f0e192155e3c4e786e5b87d5a1a3e6c4ad.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bc744db85d4237ee9640f1658e0caab7648e3bb6.hip 
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_e8d8fe5f4f8641998b8b805a20b2ca92d019ee59.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_549b6956eaf678f7eb901567d1a515eddbedae5f.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ddcb1cfea1b0dbe50a02252cba99428fd977527e.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip -> fmha_ck_autogen_86d73393d0d8b769f30222f7817563a955c36dfc.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_249668a3212cd00edaae871758be30a5a1fea589.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_643b3798f11997d33ccb58d90ed6c10d5411b735.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_adda7ad787524e3e47dcc1b65c41b2faea38f55f.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_5d7ed4c885fb32a0b548186e56d64bab98071d30.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_77a814291d8f01870274149b9d82fb75921d6e20.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> fmha_ck_autogen_f395bec57c3b2e6e169134dd8d20b287d7405134.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_0f588dcb2ef86677ebf84e406eb802e9921d1f1e.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0caeedaa7d50f1741d618fb6c573529eebb075b1.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_1e33ce1fa113b221e5303b4093c2c4e748ce8298.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_ee974931e65d6b16b7c868d462b95dcae20b7513.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_85960fe542635079de5eca3c7785890cd4740005.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_089de13222caec1483207d4a54249f8da4f9c151.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_ffb5b7349a671b182d73c8016590f26fe06a4cba.hip 
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_768c80fd3ea17813df1bf19a158186834fd00780.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_597a0276ec419f18f060a5186e6bb703ae434ac8.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_fc86c13e933cba40553ffba31d53aad27415ce4b.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_31c3760f5978baf9780ce4587ae4c768af0e49d1.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_92b0770fe64e3c60b9e56170aa88bbf74802a813.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c487a1a9933239270f44b1e08e1cf5323521c089.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_3a1dca5feb864e8981387c2d07e62acef1730aa8.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_96caa2056d99eb67ada498e287b4fae984397691.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_683e8a33fdb7053760c9c135002b0a94facbe015.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_7726be8909f631c04d4395fa4ffd03a736f447f1.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c197d1f050f42d82e6851fa286db6f81ba197f40.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_d3a23ded424200d0c6f06b1dbd0a7b7b0e7b5d9b.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_6ff4605d82507fc4bd6e96095eaee5173ea41973.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_71e3980331dc4bcec6ab6f4c345c7b5f71356979.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7e9c7feb747241c9c7de2adf3a19933a1c4c0995.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> fmha_ck_autogen_1a236be9da05a07d11cd28034d90cdf89941a172.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_ab0c3fe9529e24327686070731d0ac3ada76245e.hip 
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_66be70b088b20fc8de464167c35745461ddab640.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_fb4c15452f9155c5966990f09432e5eb7e28e785.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_17b9b96edda151072215502cc2b606bf1f6f0b03.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_f36aaa63ed42a578b953ebd614318d44cf44e8a3.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e578ec9e09d3b78dca6b5bf0be1538657f02f319.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_09513bff5c1da6aadf11d2e8272a422eabff21bc.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_f020134822739be6fa0bb3d98e9dec79f025324a.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_7a13d62a715fd717f0d4101f787349cb49cbe70f.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_d40569ae9dbd693c0ab3d6ba69704d31e451011b.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2c808da5c2514806c2953bb77d5692e5d7c97aa3.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_bc79e255d25744725e2a9db9f90d5cc2b8a0e0c1.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_84dc4af43de08130a04bfa06df9799b6e9e96900.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_006c417a52a1bd7c55e45d111483d26f4480caeb.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e02a198f23c409b715761b702d7b0e6e5992701f.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_468a5f057fd5cef2df5f919f5102f47e86901e3b.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_16047b5544acef40e39932672cac6f562e200948.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_741401abfbbbdf0dd1d62df8bc3e85371ead71d6.hip 
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_9009b7d39346537aa6c4a4e46b81139f603edb60.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_75c38912947881caa14b3fc7ab7bca317e296dc3.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_1e943fcc2e64c618fc1415b3f1a0db4d70aa8494.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6d470f5c6fb81032fcd7974180297d4bb2a8427d.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_aa1041530f794c7b8dc4a8321ea0fcdd338fff35.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_ec9f63a538940e5ace02ae5b5ddc01f730adac4d.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_459c8fb6028991321b09a990c2188d854d940268.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_a2a715b7e9c1a576f011dfe5769c5b392e984f82.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_64c3c1e3dac623f07c2dc1b934ccb868cafcb38c.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_ccd0b777df1328bf24e070ed4cdf8615bb2199fe.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_6dd707cf48a17d31abef94215c5720419faa0a39.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_687f4aaafd1a5b9ee85aadc6fab79ad0c27a2ea2.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_aebff7e6605b273bad844b8f70ef031625bff48e.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_cc127a63d56099e08125b16939dac82f0173122b.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_7838849e57ee9cd292e588f587a8079b57becfc8.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e638053e01268a4c5883620fc6a9901951e2e01a.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_405e7efa263223148318ae96bd1929b382e994e1.hip 
+fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_4c69d06e3f32e3b6d28d3e54ad764b472741c193.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_64a0ca185449a49fa485892fde6af745ba758167.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_af6ccfa11add1ae49888337e84d9c446d2f67da4.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_a487f617c4b84c6a0328fedac750d41dc3dafe27.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_6e6a4475ea795935f4cbf2dc0ac156a33d754587.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_d95835bc6f000d3a3379bbc38d90e83dcaf867ee.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_2c2e75e6f659a500dd3cf2cfd65118f111342119.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_28f2e2b108a53308a0cb6c123c8d318cbc2eadb4.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_a65c43b870705c780d734f9ef063f55cf8b3b52d.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_8fc08b4f3959a2375ac03f40c4ce12d70cdc2d80.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_a673f35edd69241c6b921d6712dfd064d78ecbad.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ab877ae2a1aab04498bf2b26b3fe99d6488ef151.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_7601e6aea44b96e94fb019501be6b102c6e6a654.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_4ef35d82ceb4af2e07719c16109c6d72eaedce67.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3c64c33870ebc329921cfa3867d58b1857421f65.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_526c89b7a04758b4badbf9695b316f877b8bb053.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_b3da22d3482738a8474ae15e8e5fca9020c4e195.hip 
+fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_f672bf80a78885428b2c02e522426470653a7351.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7c19fc90e5a9c422dbf529d2def286f47dea0f50.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_76704ca28a4877a1e84022e022614709adabb280.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_0029076f83a3dc695a167beda6fe19230a2b114b.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_da29a515d14dac02066bcd4701285b9916b43cf5.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_33e7c1e5f41a451c7baff54f7238b220f1bdf8a1.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3c38bb80e9880335faaea81985ed5d0e713ecb08.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> fmha_ck_autogen_77d0223697ed41c4c2fd8830f8df6e5620db547f.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_987a617fae00fa90a1ba60937b0312c81087c19e.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_1a6785392af35e27d6697b584cb6f17a766d3fee.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f3fd08d56f8a9be1a8dd104cdb1ac58e283b5064.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_73d4901b8ef034590314048de7223a572d61ee0f.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_0502e718337eab7d47aa65cea7d3c5f641484520.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_618031345ea71cc17e458eb97a559b7c94d3ae43.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_14c4ebd1792c781d219bd21b691b575f64635730.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> fmha_ck_autogen_56de9a7dfb1201b56528740e9d8a07b62710fcaf.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_cd0453a5c3828c1358360f31f5d3b7258e17fdb9.hip 
+fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4cb1861e31df98bdfd731efc3d335055090d83af.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_be8ec1163a01b9cd9a802d8b44669e8770c20234.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_f0cad48d9bc80d58705ea60eb2dda4baad68cedb.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_ef7cc2aa1ffd38298b52764a93cd1271b4d92f8d.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> fmha_ck_autogen_3408103188e27b3bc55dce0c1716c0b4d32d6494.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_1bf767e7104cfc8322f26df35907fbf04b8948f3.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_9594816877815bc0294610ca24f986fdccdc7c6f.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> fmha_ck_autogen_d9061c204d8a85c974676f4438994a0be9d69a60.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_becc2a4d7ac045365300bf8bd45fc6d3e1e1c8b1.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ddf5339054f47d9ed6cc7f9e66ab21ce3bccf3db.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_b01dc872c24db4db0c9179fc07e17f41060390de.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_84e8ae99e184013739019c93d07caddce532382b.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6a66604bb15f97a56847a7c968dbe32d247cbc13.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_90e5c56e92712d00092ba102a5eb5176a3e5d471.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_8352031044ef2e4a22e27ad04ab5d2c02121faee.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_7dd260849b86c46b685955cab54ba07d49b47954.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_afda8f46b5ded4c2aa9d722fec17b75004b59f7d.hip 
+fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_98e484adeddf3394d8d7693b808d83b64c71ee69.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_cbd571f4fe576fdb17d5f75a558cb6747087c7f2.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_378bf438642e5d863e31145ada2a0688059aa5d9.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_95530399ad7b43d8ce2c89da24c71056f2146b18.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b00e062055933388e37525df5766f3c14cd3538a.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_236b3eef02b904304348b9d35f715b639d63218f.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_069c663be0267c009be4814e9e4e7c13ec999411.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_a017be7b8bcf303b30a147f41346898acc5fab7d.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d20d45aa85c0daa299da98c277cee826fe67bd27.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_b34c1ce348c3d9cdf6bbec9758de9d5fe94c43fc.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_1c1b0f85e085dd0769c566fb16aafe5ab5952714.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b513834918d5ea789e2db21abece7c2d3532a7e7.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_0513b2f3bd8ad51315aadb7f63737201898adca8.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_4bd4d46397a3749646b232b306688e52b8c6e584.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_f12f1f1b679cabab04218037ef370d2c7e1fe332.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d623b36cc3f56d1001b2d3abadd8a5628fefd014.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_3f5e01b4f2ca8ea10898c39d6570bd74e85f46ed.hip 
+fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_a5bdc110955c05c6c6ea236a6f60266a4a6dce5e.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_70c8e45f6ea7cf5dba9eeadd0b19481d9f5defb7.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_b5371415448fffffd58bf014dac9f4876153657b.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ae4e80cb185759dd9b3eb3c67c239964b3694caa.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_096863cd93d1b105a617d0daa1d4f37d7fb6b893.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_ae8d0bdde763e617beafc0365ec4a3cd11df6c55.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_f7cf08242b3fb1c643d4149bec985b667b9d28fa.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_44c181996532676f2140fd026707135144e9d37b.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_8f6e463eedd3e65b9c79feed3cd92ad8cbc9f036.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_9638c9618dbf2af119e37596f7eb0fd3f8d72748.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7f80d44e82e601dc48d4c8b4e710ef7265894b6c.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_85908fe6dc9c629c82d6953081b10021e64583b1.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_fecd7501265b4c4dcf015485e63e2324304f70d3.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_3b508b92f7e123b21658f6e17d624ffa87831fee.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_01e2428c5447aa9a78f79f73f31cf685c586872d.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_e088f0f7363804cf5403adef70828ab32d09a02a.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_f4900c0a5c0d03dc17d7a907ab40652d9920e756.hip 
+fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_cb20538073888bdb3174a8e9c32d7449072aa753.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_6a3f42d5c9ccdd3807e488b00f02bc6ab5d8d99a.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c9f1e7e478a2208c4d32e2d7e6abebdc16bcc5fe.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_8457ea5726149efb8778e6d90798b8e48288fc9a.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_37ad61bf8427a26775969f8a9166fd0bfb7446b4.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_72abb25dba0c48b380b2dabeb6ab7efaa706d180.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_1a5e18f6333ed2cce509f07cb8bd5868951d66a0.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_091cb49c1958fb4342d79f367ea93cf2b472f785.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_a93324ccf11b273ed20fd960c61df897c8890b1d.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_906fa8bf5e992ddc25815486ae9c24d8bfba7227.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip -> fmha_ck_autogen_6ef5803b33d97db72eb8a8528aeb3fc956a938cc.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_a0874fc5ac87a1ec487c7722bf3b1bdaa924ee09.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_e7ae1294b6dea5c8b93c2b814fa7460c4047105b.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_238e4c1ca112afec494fbe47a85b553302c43395.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_ab09941bddfa9d61985b55f9b6bf0edec9bb89f6.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_48280c91d7cd8712fd533e246a6b0f758834abc9.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> fmha_ck_autogen_6a95543aeed81adfb6d847f78212585a36122ae3.hip 
+fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_6767cce35ab784aa42ebcb75af7305bc38a8721a.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_31b807c48c472e9b1311a6037cd98e21d6706889.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_dc4d27535b9570b8f4b790470a83c1d0a9a2b6ce.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_ab56e886d53a1d88fada0f10f00b9f398dc54568.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_8adbdcd28cb2f078f89adf9aad2b3d4a0a477823.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bcf8836c8cf932cc2748e313885003f0e11a887f.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_2af6c5be53732eb1939a2f93232af7dc011dec1a.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_da9f6e1d59132fe96709490af25bd794f267851c.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b31f56244076c501cb09b4b90975132cae4c4386.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_f9c58761c927b222112cb5cb6c9acb5d3c915785.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_041a0718891596ddac1fb0088637029233ccbe60.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_9801b25e0f132d647934deb395b62a3f70cc7c88.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6376eb68c550b50b9aea42a7a2cc3bda186b0e40.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_810dd4e870ceda3ba9b5f0084a4b025b2e609d57.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_a821661d8280c6e9d27f2c9ce1b3c855387b5a76.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_4be4a98f150f3f9ab6f03b5fd0968c5454565c9a.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_96dee49ec6755006d67f0c30c65f50558bba69b0.hip 
+fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_83d580a612af85533c87aecdd7b0345c71b75980.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_451fbbdc2dcf2ec81efce34673ee6c425cc16ca2.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_c4376ac8d82db1bc25fa273a80dfbf8b71ee5e2b.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_910cb8bd09d287a1566265eb1e8894fe68d3cc81.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_5b7a4ea3bb8905a22ae97a94c354b1cbe38093bb.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> fmha_ck_autogen_da07d8b5666423da30a95e3b2cabd3839d200981.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_5bead6be6e39ece0e5d44335083336f7f546d2f8.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bcb6f0730fd09b4c6c60913425927dfdb8f83d82.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_ffd868d49abdb769ab82c21508d655daf54b8a99.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_d9c3e27b522320dcca5ee84fa534b03aae2bfea9.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_c323a4d1f24d59bddd20ed2f2fb6446627b0ae8b.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fa16fa84278b489af253b52839786f94aeeac36f.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_bec9e4c0317e8d351f60258ed6611fbf365c4024.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_13d5f2ec83b3331654e37ea0b44d88cd98abaa37.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_fd614df484b263deae3b3c20adb0ce7b62eaa651.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_05e60b3ab7477f9edc8576a8bf43e3a62b8d5ef8.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fccabea88b8e290688c1b360875d228e6fdf1624.hip 
+fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_486f6c7c7655c34b7b9973ff357b0813f0a3fd7c.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_3cd7a9ca49c1149d46f6b05b0fefc41ecaeb6ea1.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_5e62968de58d9df7d687d671f37d63393f189321.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_807545400aa6e70ff49a5f38ed6a218a180bd87f.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_f5803aadd93e33567aa6b23100ce4fbb6c040dd6.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_4466b6c6b2ec3acb40ac1cda432efa1e4e62d9d9.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bbfd025488e52b97c04995c4c5faff371b77e4d6.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_b298e213f927b518c693660110f08bdd94990ef0.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_d090b771a4f9750132f549c82a88b4ab00dce5c7.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_9068ba8df8b0e977e9769f6acf6cfee6b00b9922.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6d17b92fab5bee7717bf9aff6a6bef7cee3816e7.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_dd10bbf37503bbc92af82bc3487989b41b20ca85.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_f0209426a8e6bfeef7d8ae7b16db791888142298.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_e89bcea4393593313d18a4aa6dcb44cd75bc828d.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_e34b7e452a4db74189334697e3a240ad68085f0e.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_615430cb65d8d540836c7f12b3367abd3c8e63d2.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_afadc4f76e237514db0bc0203102297b79730bd0.hip 
+fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_249e6b93baae25dff97a0bc9145a8d328ed3f317.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_c806d7803d06ef8aac1d5caac9f36aafd47653d5.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3163272d25bc2db2ffaa1fea87648b45ee68d408.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_b9baf70220079e6d4e87eb01a7259923d8a01e29.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_c5fcdea177734366d3bf283317a65cc3fffda611.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d25ce4b3e9cc392ceafebc7fe3bcbe05aaad4bbc.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_3bb129e6dee6848043dd0e8fa812ae80fec4d014.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_7d2f87c021e0b6a27b2d7e30351fd50f06414b5f.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_f4c803838f5644ccc6f04f7c8a6233fed0b6639e.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2c82e3c4e445e1e02f14435e4ca01a90850139a4.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_a21f3637624762547af1292e1b85e640b1d329dc.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_c9ba0a3369d4e4eaea1c902a90e6501f232dd57c.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_1914250fce818584291c69a5f058a58cfbd83df9.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_01d3b034a2d8d0b83c0aefa4faac6c3f28ce737f.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_5d707d065ae152450f9def619ddc3dddb9089e88.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_1132b11429034d96d82c82dbfdb69e460ad8a564.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_4a5dbf601de5754c03a03a1a42395dc0766fb8ac.hip 
+fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_5a29b93cee012c79d4364502f1d90f947c73641d.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_01e8f0df0c54ce619e5b66441b3c96a5e18b05d6.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_1d498e418ebbf33bed58b4074d1edf3d9bdd07c5.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_4d7dc0f356b630179916f8fc2041b7f1402b46df.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_292454f2d82184ab0491ea0675750c6ec55d659c.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_c538dc4f65d02776875627cbd20a9c794d70b043.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_2d1f2d1e57095f756ddd11e8e9d4f6f253e3ffa3.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_16f94f5c65c37624f5458c165daf83517d9e3c81.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2703018e71d57d3266fc35e2e18a78faa3dd52ce.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_ce5064e27ba427cb951f7e1b01328b0beb6b2b7c.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_aec87e65afa93e84d7a947c52f291c1c7360033c.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_30f0200092b0e18d57a9f5e512d565f1c0229436.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_61896aa9e4e4d7e494c1755b1e77a08e0e264f8d.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_487724686efd35731e5335efa949486c93ae26e3.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> fmha_ck_autogen_3e61b019e1398a6a3c36143fb84b5ff22c9f4508.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_94a94d145e575747c8956ac703810582c819e2e8.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_4d3b1ae63e127b6e6afe39e354d4995afc5faeaf.hip 
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_438e3565f4c720e6c9691b0d33c1392936e2e7ae.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_d3fce1e11aee2273620e75efe4aa0390fcde9ba5.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_1f0cad6ad5b172e51c569e84cd54a19b4eb0ed05.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d54b3731883a5f8393d60d27487f8d017aedd3f9.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_0efdaa9266a5a464009297dc59db92504f8bf1a3.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> fmha_ck_autogen_99f8352674bd6bbe98944a1c0a769a4fc028a623.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_a5f2f0cef657ae5e333d65ae4ab20529a43cd7de.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_9eef1b54d5d3841f3fa6b84cca6c7ad33efa2d9f.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_92ba64cdf615c1be2865f027a293cb530fc07dc6.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_931cf8d05cfa45319f4e5bb49334d35a530bffcf.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_34807a8e90bf1cd839f32fd718afa6469c35a4fa.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> fmha_ck_autogen_1a98bcbe900f8c141136d18c114b02fffbe8bca1.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_63f121a3c8928c10a2d86b487cd13fa995da670d.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_8f607ee20c0d92b6dbd0338f139517fdcce98d0c.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> fmha_ck_autogen_3a6b9566559ed2b1c85f2bea1c55e72c41dc47bd.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_45f4363f50af1e7ccd24751d5f5b181bf32c604f.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_22a07ecf1a59f72ec6bef3e970d7f33cf54c5f44.hip 
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_3400f0af03743dce328486f8fc805dd30bd6da31.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_9b841b7cf5da31f0c30ec42c91cc8d5bd3fedd03.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0e3f4cd28a4c06cc109f6a0798a77844bcc750b7.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_a103cd47156a98ad2cf2c325ea00df3f1d67fb72.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_bd37f4f7914805a97d5073f1ebf8a8b8c2648d31.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_030a759dcc92028b4c6f317fc230b98cb929e806.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8d79fe8a600c3b4e0ec9aa510f8036ba2b608985.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_10ceed95b0a0a01f844678717c88e0426fb503fd.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_90b17d8cba28cceddb3ef907df878aeef0762d15.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_b5ac596c636df55e81293228cbc53dcbb3024e5a.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_e68a9e05debd456a9975953f7b0d510e7a0f6978.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_50f915b4d9bd18a3c25a85917392ea4a5e88b349.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_dd67d442001d2b167e70e8730abde4d4461b8569.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_4160f6b6d0869740a5a411abd80108f729f810eb.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_40357c5e9739eae136a7abf92bc38d3ac94753f8.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b4ec377c44ac18527ca6a01bc3b146706a6e1e09.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_02d88a03cd3966dd0cff550065f58c3ffecfff6c.hip 
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_093834d4d3fe76e1745e4482c6b51b550c6f3dfc.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3cf45927b6d931e31e2209685d787efa28eed8ba.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_634d530731c7ade2c7beecfd1bbbca8583032217.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_311731442b756308c0a869f21b7b8b103aa613e8.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_2ae344010d49f7f9a6caab2cb84be7f87d2d96bf.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ae239476d61f48379754b97f29d7a285cc3192de.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_6e7e1d245baabe2f6293e3d85318f9936b333500.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_f6566441ac3074578cfe45758ba0583c0da0a5ab.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_de26a187c4db06115072a5132e1166b5b03368b0.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_71dcbe9f481c92215f3b636bc0e86ce8f65e6472.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_9a20fa19d8d30654602e363806f559113218d66d.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_1e22f2d99804198c61251b4629a3f18ed3dcd42e.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_38abcbeaa4d33d3150f2b0238bb62ebbfe960980.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_d0863830fc5d43dc6d6400280e892bb7de2892d4.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0cee6b9427c164d78994150305a47f73954a67c0.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_04caeecbc01667ec6f5599358a0a20423aa9a00b.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_33099fcfc218ffdf69edb4f2f0e46121bea9fafc.hip 
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3dba3cd44f78c950fe7ceaa5f0629dfc607b30f1.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_c2f04447e6a94c94a2315454e71d7d607a9fd0f8.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_8e2d5f979fc4fbd0991581a020a414f9c8656ae2.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_1241814f76107d74ed069ecec99a248676487eee.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bd28203f47b6a48e9b66302cf8312f3796ca500c.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_3d289100991d4c8c362f64c8f6c4ba395c2f3495.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_7c23dde1a386436e9864c8fa5f1706c0d2fbfd0d.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_bd8bf7c572c1984ca3061062cf3c31d993f6762d.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_164a947a6c2ba83a5b1cb7074aee0bdac6c9c64e.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_9b062dd633645772e4f2caffd111af73184f7657.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_abf6c6412f9853855b74a96e862935ddef66f763.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_aebb2441e6cc1ccba4a391566e547402bcf7ced2.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_66968bbf7e210911fcb95ba90c79837230ab1ce3.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_be1e1533fc37b41838bd37edc2b6d2f2e76ae1c6.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_3a2280997eb6f1d091094fc54cecf42b7c9c3a2d.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_4b4c03c916393d6be7c5181369ebcef949eaa763.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4ff20bafbf156fe8fb80bdd84a5d2f3a4a944c1a.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip -> fmha_ck_autogen_be4dd90ccb2f258029d0156cf23f940b694cf08d.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_e334e691714f0b99773c2ac515ed82de0f387065.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_62eb2f81e73d65fddce7ff43c397da6529317607.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_285e61dad8f63fb973cb2eb899c959e400622652.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_a2ef5d30a2318ae06430d17f84878800c4ca7364.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_54548ad36fb92d0963893146c8db20f53cbf0c8f.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> fmha_ck_autogen_3967a8807c9451b09227c0f685c18aafeb062fd2.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_94f6f9dee9f0c3825d91f4d320a5280070e60ee7.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_59d366421e0b51c90fa53c366d47ed8d51b3a329.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_dd35634440edb25cb095800b882c70aaceca1dbb.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_0628931bf5cc1daa6e106cf60bb21fa1aac6b1df.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_ae4e7253ad4873576052ec0a9400597bb7975753.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f6f102a388ffb05c690a20a29cfe0b35a35eed61.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_235bf652702c2976551778b9159e09188575c63c.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_2a45129fc4995abcb8f880692f11c6186fc01641.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ff453e3bdc9752cb7b81f7cc3056325a8b9a8ad4.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_c08095341ca7e3a1debeb780c1878e351692bee2.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_5de27c4081377f59363c2bf2ea8624217566d2d3.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_8c4688cbd23727dd0ea9a36fb977b31aeae98d65.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4666db0ff7b035e54f2c0e59acedc2131b722a55.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_783ec08544591a22f59dc12f169b7327b4185a1a.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_131691f01cc7f29affb88152dd48c7a484315dcd.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_dcf815ef540060cc7ed43e1c57a28e1d080c5621.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_d7adde8780b39f1364c572a19c3bfb19417678e3.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_cf5c6c0bfaf98f6e655fc443246b81fcc730fe97.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_b18a615e66d7cd739ce35412811359a03cb23a8e.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_0fbddf533661642d84bf5a16149692d5a892182a.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_fde12cd366d6850ce26afce98e5076b695b4875b.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_df0b2bcba57e77d975ec5304fc50cbd09cddf4bb.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> fmha_ck_autogen_cbe5a98163e878c7697e554758ebd0597c2c1760.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_3cb0cee09d633b6f70febbba63a1e090522cfb4a.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_eeb0e96b759e18cf703cfab0cda1385726f6e0a1.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_4601680af41c8738089ff377147e0547dcad114d.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_6f3d098f8bb63133924aab70d26a6ed64018c13b.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_7d08373ace7087bdaca4ce8b0bc329f553f88d77.hip 
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4ec2075f394acfb14fae7b1ef4304fd9b654ba0d.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_8a1fd28acfe85b3adac859c4bbffa4d28fe634fe.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_7bb7b63e8a4c1df4eac4d978e166867195bd6e53.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_80fb694fce7b4c3c459fca43c89c6002fbfdaef5.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_86513d6e065a44bcb0c789eed1e7e5456e800ab6.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_31222e158484773d2257f4a31e3dfbdb68336a8e.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_b20c6252863a73341b0010191fad4c834860f884.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_70cf755f1485c065222be4daab84283a9c3d0eb7.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_0c8a0bb89a6f05289c0405df5126fa0cc16252e7.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_88ac7f6cbdfca2e397bcb86af4216e87166601c7.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_db8f0bd93b352d28c5b6d78f4332026993f0bea4.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_96c5e79f54b71677124f555b0ae4bfd27248d099.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0b532fcf26f90c82a792cde7943634f667c1d033.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_b6b17ae67adee9e56a022cd2a5514fb9c4e99920.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_fa62a97675719c2e8e9bb97361b92ff1c7b9d2ef.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_173c44dd85077e6b12dd06fdcf6b11ba349e1866.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f861d8693f82d22e2c5b1abbcbae5f30f4433e5e.hip 
+fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_970073c70133ff2ee4737f803a0ac43801c47242.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_5aba1183efe205af38e79a1b2dccea5fa515d02e.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_322a86568f89a5a5a165cfffbae9ca6949f2477e.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_321500dd4c41e4d68834814a48a639f5ca36a2fb.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4f4a5d56721bb1a1332a65882132a8c5763932ec.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_44d82b58fdc3e5b7a7c20490ce7f5acce4e6ec79.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_678a4a8210a972bb2ed89d6ac754fb79438ab2da.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_085722b43cde5f37242edb071f639da7c4a0bd48.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_417b1cb14b67dc82f614831550f7deb0895bd7e4.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_7ec04763d635c5bc3e810737b5d948c59f117d5a.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_7524904ac5a2040c7ea72aef5942212f291a21bf.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6979ef43adffdb62100270a62706fb811963925a.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_5be9ed84ad9be1627db7a66af9370679816c0897.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_ee239db5a67c23a383590a651f0d8a0be43a13c7.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_69214eb450c3b249017480efb8d092b0edad6dc3.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0c32a2d9701e23dd930119c4ee8089042b5b0ac5.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_97246460c21bc66c0f13936d27477a9fca1c44d1.hip 
+fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_078b96ad691a85eebd18586db0b62b8911016d9c.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_7ee953cb24e28bcdc8f05783894b23cbf83bdf35.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_921f789d619db6f225e8e9d646e93bbc9dc1a669.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_28e4d2c757e4b8c366a2c320360e21ff0ef671a8.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_da6afccdee4107507a64323e17bf12c46da2b92a.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_38e12dad9e3bafe177ed3c27c833825813e18fc3.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_151a4425b411596c46c7032f6b83d3152a0e0cd4.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d1d3eacc320104100bce46235fe656e5a8223c66.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_a71305f191f06cd53b7563971c706e8b71b19e2f.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_482e34930d11ff493007b1613993e01acc1af78d.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_342d29c85070f488a14b1915f948e5fd69019c99.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_b0f555b74ed36f1bef8f47880b3edc6760f27788.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_c42ab428503e8f8bfa78c8cb8d9afad9f5185118.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_cae6c7efbfc831e2bcfc8c1efa1a486c02627cbf.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bba10ecb79ede07324e1198a71a95ff26e9eb235.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_48ae3af78583258c4b13c11a442022e0e058bb85.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_82048cf91270631f98ac37dc488a1fb2e00ce004.hip 
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_fb4c5f8fecfbbe16e6648becb3b5ca89fa3d8a94.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_6abeb7b50ae6a1fc62535b9a1dabbde6f177a9d0.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8a824621a50cdc3cbadc4b1f9ef18e1325385082.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> fmha_ck_autogen_f69548d6cced86c21c09c6475237a0cb926df0ed.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_325fbcb9e503e68fafea08abf86a4951f440850f.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_964f916d3484295b5918e2e4c22c5529588a5662.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_df645b3888dc8d1df50c47c0d75822eebd3eb019.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_75a310a6eb86e3e8baac7a930c3ffbef372942b3.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_458d708d13577f2b92e6d5adfe952a87e0cf7be5.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_15fe3e8f4add16a088fe44458353fa7c0c4f9658.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_0d0e0147a92061d32608a34e7b47bd534eb787fa.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> fmha_ck_autogen_4e15e4f16de26068cba30ef12fc29332d45e460e.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_e2bf6805a489739abb77c13173d57723e9304afa.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_15cf7068183421b141ed5d6e7fe902d06b6492a1.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_703246f1f53a988cf252eff88bdf814bd382d3ac.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_381b29d9888365bff0f109d897b508eebfd8a61f.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_f2da112b1e07c44fc8a7f19368da203f6935049c.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> fmha_ck_autogen_1886d4bf54b3a4a9e093360998b2059b3c03d072.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b60a4e87a7aabfe3c1ce02b408522f3ec862e3d7.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_a62a2ab489839ea1a1bfd1b24e54a3c232ed934f.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> fmha_ck_autogen_36a0a960541bd8a2dc6741579de685b7c0a5f6d7.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_1f6bc5faf18be193212217788d476ce6fd384bfb.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_13f747525ad31e76c88774fb2208e470da9c2310.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_71b6100efe30d836dab557ea4ac54c4b9d35c6aa.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_62ab710e4acc711430745e05e036dd6a4d6bcdca.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7597ce4d2e5264bdeda47487d5bdb55a014c6616.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_ec7ec8d547ee9713aa3b5b667f22cdcaa8f62b2d.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_4fe530cbf6363a8f08a94728e45e88ecde299e7b.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_661ffaf653085dd7f122d603bb3ba4b001e5f3c0.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_345ea796c8d97bfe3b7c9663bf15e2e5e7696235.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_802b21f9588d72c3c3e3b9a3b269f19c484d5aa4.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_9ae866c7db36286876818bfb718ac35204fa3843.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_faf56e45b2240515e97fc1bfd552eb03b6de5094.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_fffbfcac254e33926131a71905e93f9cc0aef89e.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_919ae177b7a793fa352c4f6bb8e4175f3064d814.hip
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_ac9382cf8bb56ffd962c99329bf67da992f8810d.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_cb1deea4f4fab0db31d46a91228601f0c272d6e6.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_144f19363ef26efd36f0436cfa9f84f181a8824c.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0392491c5a6dfc742c2be483419a40f6a7a7ea56.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_cabb7b12cdd9b8b522af577e13232b2459dbd38d.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_caede7a18f3e3d5e24f6c70392413a2cda16ac15.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b9d00ab8373747a5c6b9d2f8dd50ceb14db4163c.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_ce909cb5f96a4884caa0d2eb8c5e6bc7fa352797.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_1037f1bc50c4a65dac09ba56b701256b701c4322.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_aafe891dad43815e635f81225705ff944f990d75.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7afd1a756247b15b078d15a39e350a07c22982da.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_3e839660557dee9d5bcda9b56940ce23236c5f6d.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_fd26e43ca652e6f58ff48c356165aa4349833b55.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_b3486244e0b7d6dbcaa1951e8b8883ce441c3f99.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_90da0d469cca5c8481504148468460c85a15c559.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_714c5369aa848021e020d874289e3ae4e0f74d77.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_cc54b107e1b557ea36b5cbaf7fe3dfce05415c86.hip 
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_00042c36bc588e60a7c8a9ba297a8a25d8ac0660.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_65794d9c185b21f59274ac5d4db10a7abc0be968.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_faf686067fa433cea5e95dd523846dc881eff635.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_39d3071347a0c98f3221104036f477aa13bffa4d.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_e76879f8ff4796f48ad87ff8003f4f6e6adca9a0.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4377ac04be3a6cbdbfbe57612a469412812fb5b5.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_c1b76bc7a17f573c0d52c07ae9ff4302662ae61f.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_d713fe25dc90b3511fc259cebf463376dcb55d84.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_89a3327da9a3411ff1cddc67eb647083cd947a92.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6a7eb3d86aa385f9ecffbc5ba10489e56856f918.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_4d65e58c9f147498ed04dd51fe1393770603a6d3.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_a5c0109313de1f6245d2a80f8539485b849e9d55.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_e73a776ae4ba68c23acab1a5a6381684051738ab.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_a225c4f1f3c7b271957768bb9235131c67afb48a.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c9530e20038eb40c49bc8b045be0cf4e7e6b4eac.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_f51f1a11f778d99a00aa5959a3e58a41fcbfb1e3.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_de36bc309877917a18fd21acb30563c7e2f233c1.hip 
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_4b45948f2795293e72530b02669c4f549608ea7f.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_05f794c7023cbb7e35f1fd1ae45bd2377bfbc520.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_82f1d7e1a93bf2fa80c409e6827ea88af56c44f0.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_4baf664bfdf070362bcc91af77d1bc406f744351.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_80efc341089a50ed5669b3c86f6ddd9b124d1442.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip -> fmha_ck_autogen_e465193d97d43237c22c04478ca5833011d8dc8b.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_915b75db795dbef037b14b003ee073665fe35d3e.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_fb5bb49928ce5515d7b297d5eadd4ec70a22d60b.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_349241529745bf138552f49d9a93db418663ad65.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_c4de1bc135191f3c2aff740f4c6bb7e98da42f84.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_4ce03571f1d2779bdeaf0a6a2d617e236d191c11.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> fmha_ck_autogen_ea077e68dbc1bed2dd20a5f4dd35e0cad6330ee4.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_c56aa150611b0d4800470c1493dc907082a5c23f.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7e9519dd0d0f940fd5efd61bd32df7528ba7e3fc.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_47548aa042c69bb9c59a8bf706b44028aaa41830.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_3dff884e176ec7cff86d17c6afe1ddaa4dd6007d.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_2081430c92864c29bb9f409e7c27caee1de00749.hip 
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3d1cea88a2277b87d405025ba256272a1720f88d.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_a55c7dd576e5b1061c059e5e99aeedf4389e2d25.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_8c074afcf33e3f3534ac3577484237fcfd2ca48e.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e618fb4e529104fc90069c8779ce5463460bd516.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_44462715ed5f192532760d6f4c66ff9d4e20e254.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_e1d85ad2c9d197f501267fe0804e6985802fbd18.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_78663faeb0425f45e8a0da0f7b1a5ddbee5e07e7.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_15e8e1ab8c63db96843054bb7a98d708ae6a9c44.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_629e0b97b3fece7c12504f4c8f1860d611b57269.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_29c9e5384809b21f39e78bb2e43af345a9a21d19.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_14f77aeeafe4b28f314fde5ebccfd2a554872781.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_09d76cca48b71dbcc9bd96734787209fee4c9a74.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a4980becb0d3149fee575bad1fc3b463d08aabf5.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_55bf8444c1c26b91fd490c7216f4d0f8aa0a1f1a.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_e4d9a2396ceccdadab24602f30e9070901a76dc7.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_14fea611f3c253aebf726af3e5fdb7e63e18e13a.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_587fc33d02b1932235b8d152e57559060211d591.hip 
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> fmha_ck_autogen_680e81c3700f130df142c9a37a368944ca548721.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_62048a8ae1c0096f3372b0114c15edbe813425fd.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_9b4dcde1ae3446b825dea739d4295c1d1ec5c4be.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_ede81dbc4cb208ef6e684c76ba1eb451d37fe10c.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_59901147b7188212b8d8feea15831a11425fe4b3.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_c9ad71883a19b522486706d3705700c012a6fc19.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ad82071cc074fd30437f6158b5eb2c6df1f8c587.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_bd3daa5f99b4522d932334924347353ce2854821.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_b72a804bb3c99830653d41ac0bd49943c801b89a.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_572e68bd619e118292768f0925ccf92cbfa68415.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_ee1a43f2210a8d1e5623411c95c33424cee5e747.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a93a03b33305b33055273711ab31a5b8d8298d5d.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_c3cfaf0d53869c373f6d0ec821b008dbb819141a.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_b2af5f5b5ee3ae964824a3e9c7bbeb5bb39c557c.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_56964a17f902257aca9d08c736516a2c67d9a0e9.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f9824fb32933b27501ae8a7f43f460a2dda6a814.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_4118e3ab290263ed2576feaf22a1944bf2ddcb7a.hip 
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_ce5ad502dd40353312d561e9f40aa478c16ef5b1.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6d07bf9c05e41dcf2416e05dab4bdde17158db76.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_b1c5d55d47d6038e9162d32ac968ff58c0942938.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_30c8e4d5c761fda50e010da779e8e4730051d403.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_193699a5daa14ca2def07489e0b563149bc403f8.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c0342686e4efd26413c6719782ed13603479c4e0.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_fb79e1f9231692d736dbada062ed6821f34927bf.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_5f3c3bed2b584ea2031debf9f953f5f8f7012171.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_319df310195191895005b30151da8c1afab6c82f.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_6af23d1460abfe875e71f7911697c42fef0f41c5.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_cde0582e1aef74f9209de638b553ec0671476258.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_5052b2318dbb78b1a82ef03666a35a623f44481b.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_2543da478310245e19e6c6a0d9ed7ad99540b3bc.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_20f7ea0aabd069362ba4bbd66623cea5b6e1a6bd.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_da74887afedbd67928fe4d596709f9ff92530611.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_155c3549d067464d186a99b8205317cc000d4898.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_089a347aef8a920e3b59d5ffe71fc5bfe002609c.hip 
+fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b0dd965d5d9080ed5c6a04b7eea9890f3a264f20.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_2db33b5442d2e0948762b1f2147a321a9d6907be.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_1cc459e57bfed5ec7f40ea4a4dd9f72f3ad7a709.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_02ff94e3c787a7b06ffc90c25777fa74f225e32c.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_86309c036d96367939ccc3e8922595ac35a3e179.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_8b92990df507e82f96eeb7aa3ec00c01437566fb.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_26835ba70606c769e56d19dbfe74061361aa855e.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_dc1a7f9b1afeba6690fdc0d0d1755ea89c805573.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_815918206483d2ae04a45aa67d69dfb986587214.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e1c1a31a1d8556cbe0b6ea76faacc78855108539.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_54b6e18b10d529eb6b32d7c19c59eaefc7184376.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_a622fa57764ec746e02f6d4bd4846b48c722b807.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_3b5b3c218e4a7b459e54080e24c5b730221eac02.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_85fdde4b25e2fc8cbdd46c2850c19eac8d9af8f6.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_a4b7f10440331a8a88ff93ba253217c2832bcf9e.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_4e9a933b916285d9580a76df543cfafc88a536cb.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fe8b8c3525fe86a20a2d6c69585f3e36c16caabd.hip 
+fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_7d12e9cb599d24631c082e3cf65d2c58b6d4d44f.hip
+fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_8e812705ae3e452810794fa7caceef2ef6066dfb.hip
+fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_7cdc419d4248dfdeeab1f0980aec35fa134e52e0.hip
+fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a046e888e3836b0bd3c49fec8e1872e880798f0c.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_8278845045d68027dcf3bf867ecde2fb12ec51d3.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_18a4d71b31c451a50df7996e3db864bc3c3882ed.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_5c36fc744dfb0d985c9113175e76c7ec1c935054.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_b779cc0b0380e1e6a2b51fc6216fdd72215b882b.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_19af6a7f9e5020e8d0f0ca0f6258001f6ce592c1.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> fmha_ck_autogen_459ea3713aef9b916e1b38a882a45012930924d3.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_977137b371df841993c8d0584be7d83aca6add78.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_7497eca4d1a18306b406b367653622a8d64095bf.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_9bf235679af1ca03a6e601b4cf6cd0416d1c9091.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_7177f939ac3dae8749cbf4232dcf04d2cf63b48f.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_1847fef2c06ea581b0ab31af1cb0556c572696ad.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b5bccc85f74f54a2ceb17fe3040b04fe306c53f9.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_f7aa9c39b06e55bf4bc9f9a2a0fb075c9d4e69ce.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> fmha_ck_autogen_a78fecb9725ceb4bcf2aa037d43bc43efeb1c3fd.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_f93bf815b520a9d9e17b43bf9d7fb870751b6225.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b24f91dec2029b25d0d96962528410df55a468ed.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_00a2adbe938d458d51ca5fc4020667a215b672a4.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_036887daf6cc092e7422a17882488e59cecfb643.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_1a96f0ac76f117e66eba97cb990c2350561ec2ab.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> fmha_ck_autogen_0c3b2ec99fa7b09c7f78dcc3142a661d686044ac.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4e760de14b71a41882ec4a2c7362565af36d1a5d.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_94aa519eb57e5797125728492d9330f5c0f0670a.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> fmha_ck_autogen_6bad2ed9f91bc1efd89ea66cd5c775fa140cf931.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_9b73c92a13757877f34bd8a13c6fb29b60999020.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_1dc6e599144a093203fd7f92ac6d3c2cd7180d49.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_7e6129eead18d13a4a6cb9550384fddabc7a2a16.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_04f39b453505f68a5091f68b1c3de48369d1e7ea.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c5b440ca9a5196ee1e72c878c87d96934e9273c8.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_cb4576e8ea5d59d7663f3760009a00a19e1b0667.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_44690e48f30657b0fcfa26fb3b9af3ef76e792e3.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_b872f9e6ebe330cc1818ea82b53acec79a2f672c.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0fcb7492feb79e27e0bda73e57ef7dab410e2bb6.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_7a242e5953f44316b6a4f6587ec26283ed6cbcae.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_2184fba2eec5899bb40d49d4508196e6be1ec1b1.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_06b74acd9abfbd1c4ec2f4c718eeb92a0bca7bab.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_ce5c161b725becf059fb4439c668edd454ac77d1.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_addb6a14043c5a4df0f5042b3770b40c4e90795c.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_7ddd621da88c57798db1e689b93b692b6519ff96.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_b0544a38dfdf4d81dc95894387845f48435e299a.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_11ff174ff2175e9ec22ac3a0fa59dd7713b79643.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a3f9c236d24b30bc9c3fad90cfd6eb00da835de2.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_515128c6978449b33ce0c35b02a9e9aaad65ef7a.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_0b3153af7bcdba33115a0d31f121fd76be2ffbcc.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d3a2edf232786d458e2125f8dfeda8847f842afa.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_a7f7553a7d2f6d42fe695cdc64423c85223af440.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_a9b50c6ebb27986ce5b378d8c39315eb9cb91dea.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_2f55a23a0f24ff7062a4c286944f25d2db3e20a4.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0be8cf70c6be969ecfca675782c860b5b75ac089.hip 
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_1e9130607a2d24cb0662a47e9cf12c6602143838.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_cee81ab2e2678816c7b516d2d4c50e8cb5874c68.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_c5fef330a975002ed15670e8e7b26a10376d3cb7.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_0c9bd38b8f9009d932ec49204fdea39a52885246.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_82c932e6eaaf44861c794539d9caf8b50192fc44.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_4568af1b2f104664fd05d21ad789aed39ecfa42b.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_d9c23b7f8fcc4e4f4c81f5f00cfd345b98df2e0f.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_de7eb562a7eff31d589e12945d80233aac202ae2.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a92b43d374642df991edef1f6036dc898bf77cf8.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_014c209d5cfc6b965bfd78c64bf132c0154e32be.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_1687ddf65ce4ed2997583e20fee9f201e86633b3.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fc5841a729099340d608e31023acbeaeade3e886.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_9cc3ef3d3b36f52089548e9dce522b0448e2c26a.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_8efb5fc2ace6839eac741c5e6616665845f43566.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_ef5421703cbfa63a58ec02701e245d479a1fbfc1.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b50e6df20a2426abd3d2ff2262a37c009196024c.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_a094599fb5caf5e7aba728cd4713a8d0c6368a46.hip 
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_21e235e31d6955393ac8e825bd69ead70687b7c8.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_289071756e7d0582eb61ce6483fa3c988d2e10b5.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_1899e28aff2fb168cdc3af7132dd7fd09c2e1ced.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2e30f50071113dc4ab59468d568ac9deb06b0342.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_bdab172627718278a71a93e3737ef08ad9259a4f.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_77200e875e0ef160b311c7de450c137772312d0d.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_cb1b91c16e0255fe7a0a85638b98d94634e143a9.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_877e33463b3bf1853c6d2d2009af8d27bf88abbe.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_92e53359c69bbe4d7405d45261a8a62008eb7d06.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_7764814a0de7702f0b7b5ce9dede6440603f4853.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_05dfe927fd64a564c5fad537fb7c41ee9c94c2c0.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip -> fmha_ck_autogen_78f7e2a2c08cd87702793f91b6935cbe4c22be55.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_d4605b2ad3e3753c5f255678abc1690b949c5abc.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_037c6c80fcec3eb8b0bef50ad6af6d27bf5447f5.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fd9cd1305633b62b68fb8474ce021f639f8492e7.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_d2f4b869ff23874b6bde0aab68c419108b7e69f4.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_6ff58a5186d69efd6062f3717bd315394ea6592b.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> 
fmha_ck_autogen_8021fa266c77e6b5bd1af2a9c22c686e5a6eac78.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_de5359f0fba3da9dfed06ddbea8fe2a33a9cf40c.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fe72cdd69944d2d765478d4aed13066a02b76f6d.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_6a7b6781ffff9a42beebb4d73f0d15461ddd4479.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_28f7634d29bef11fd466b452a46b0612f38c949b.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_66f651d3415562206c1049b172261fddba01ea6c.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_32438250078ba2a47345ec4955dafb4e4de78a25.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_5ea53f7c6370845fa94aa9b395c52fd1900b62de.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_d50ac8e8a03f8e7ec2c6e993dd39f09f465dab57.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e2b629c37cf94134693ce455b8c88b72a39df7fe.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_157b89d8d625b8244b5cceaa4d3e5fc5a09c8989.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_5789f267d34c9961ced63ad07ffea2c6d2911415.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_38010c9bf7341588f071f889b7a0b4dcc4e7a14c.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3d55cb42b0096a8ae338ce100f86e378aa1a04c9.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_11e7df31541c3aa919e9825ad7dc4432f9a03c0c.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_d7145383e39dec0e346b5094401acf85ef3c2075.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_04c363e11d202c6d2f4bb753661c5a2043edc0ad.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_8fb33fc20f2e85e915f1b1529ae87981dfcaf86d.hip 
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_97851d5ecbf02f8af623988b1a39c0b91e51533a.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_9163ae070075f26926a86d39e15c27e6edb1f1cf.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_9ab73ea77ec20ea3bfaf995dacf93a6960ecdca0.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_21828c7d3f5574690f12f841c27f025206e6165b.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_dc08afbff5def8bcb4e823657ce01f57c9dc77c9.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> fmha_ck_autogen_875b08ca602fe48840c72cd61798acb98540fcd6.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_216806a4598c885e517e664fc8280c59ec3cbf11.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f90410c26d7649e21e2ae5e32e7af89d84d2ea70.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_a3339150d8bf9d073827738527f6cbe15b854607.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_7a0ab620e6d62259a559e329460e46e6e3f7c3f9.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_7a2e032f6500fbc5468183415b6dd1d3e43f0bee.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_71a2d046629a4b65c90d0e18d061c4984062f844.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_04ffca078cfab8bc6c4ccd1cc8994a1bb4a88ea7.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_836a308c2d2afd6e0dfbfda61984b631c4ccffc6.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_62ba7a5a0f3a714eb5f9f2af20f7bfbc82a30350.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_e9b04e6d5527ba0b8089ba8bdd264e2d5759338b.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ce5b5932f6df9a194ceb0d69220fba9596528eec.hip 
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_cb3d5273945c5d40cc05c2660af2df1fb7a15f3c.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_5ace1c9b00f160a17355d4583d49c47887ac33c8.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_054fda16133a0d25077967b05425f9128e1fe1a5.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7adf69b51f0a8cc9ae7e250e60df38758230fe4f.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_cd757a8bbeabd16a44d149ab188430f6d79ddcaf.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_a5fa94bb32a80e81886b711ebfcf2df5f5405866.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a02f152e9184af0b3d77082d8bdf519dbbfceb2d.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_cf73e1fc0015094861ca0c1c81bacdbe0c5b8f37.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_a9df9ac4ee78e5f4d5bd0567e58a7090907c61e1.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_92121fd448b4640a17e1a7fe73bb7b58714c0afb.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2c9756060ac0e73dbcfc58a9222a78f0283cd029.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_354121d3bad1d448bd413718fa096f54faa12e95.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_d4c9f975891087e6eed6393629b41155deafc509.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_9bcc791049e3ff9ebc1a9085d2d20efcc2f99b71.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_79d0b8053ddf99a4d4447656d733c2da026b3a7c.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6f8788c537cbf6833c58a6ca15c0a36de33c9fbd.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_5fa19223cf296d7fd10e15e2571e63c84a80fbb1.hip 
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_4dde56efe17f4fd36a11cc959320a5e43f1dc232.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_fabdc143c29d5ca50ab1e96a814bda6d05b0d5d2.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c977735a36c325706bd19a12df66ed0839b032b1.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_7872c45ba170f2782c4b5b75cfc78ac79a4cf157.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_7c4710e8f4e27fae4ae079f1667c3a1879cb6da8.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_09e50367b62bb09071e28b44235a7c112645a706.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_1be43f8b629e7039f57b95866d5777273377470d.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_d0de618ff3ea9f67b90f2227fb7fcc74ea34183d.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_865eb90b1a2d64acc0f6fbe1d807c501fd4be3cd.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bb35c86443cc9ea38c06ebc0656306483c95ef67.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_ec171210efd217c07d357fcf42e5372ad7e9abab.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_bd80a1774d8b7d8bee4e8663392b97cda11dcbf5.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_b19f05f6848403480ba41d37cdbf44ccca1b1f8d.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_e639a1e84faa98477b05df71d363b9ff0f9b2760.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a388a284f45f711d82a6ed87036d87cef1872eb1.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_26ea90eb5a527434c1740933a1d2dd863eccf14c.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_e16edb824cecf459a8ec51b8dc74b1e06369aceb.hip 
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_843e7888cba5f463d19fcb71aaaab25dc3d2c09d.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4f6243c6850c0a2d2b7bf1476e12f95f187257b6.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_89617bdea526d12d6a33ed42b9b0018c0b173722.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_2b4050988e5790a28dbe10b4c20e14f10f6cf85c.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_44cc95831c347212021c0bab7b43acd7daabce42.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_ece60111633db08f765b3c7cd5cd768cbd030255.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_79a7dce707954e765d97cb22e57d9bd6168860d9.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_761bde840c0c8149b24a8f6f264e963c4e9e8ceb.hip +fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_609616f72bf16a060fa50091ac139ddc06bf9d88.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_ca1992a2634cd6674076611be54197c715ad8271.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_2f0247e301a7b076b6ec8a778c3b47e330638963.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_55b14cf2998a61611d1de2594e926fcdc378999c.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_21411df58165946bf02942b597d94de7dd856987.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b3063d06723ac70c5f8802ab49c5c35e1debf56e.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> fmha_ck_autogen_4052ca6a3ec02f6559e4bbf1edde42ad2d127c26.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_d41cd6b60a97e7071518cbd1a63abb8b910df024.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_e75d492ac3a6ab75648056bcf26250a4aa929cfd.hip 
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_474fe2d739eca8c93fdcb2c105d4154cee6ca1c1.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_2c0bda0feaade2b554d648d72f219ac9c389bf09.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_2122c973581930ab7a4ebc90b3bf1cdaa229a87f.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a20c91b2f11bb7e5058ca7935b0bda4f5558a9dc.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_9990e6ad243a48b84304b5cad0c663c0802aedfd.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> fmha_ck_autogen_7264e378e1ea1d4dd97f6949d66f3492883b663e.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_7878e2a4d3b96a552e03d1ffc33debfd50c9f7f1.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fc1eb85a00017efdc610e4259d2abe935b85304f.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_cbf3e4d4d4837a0cb33b78c4f2767b1d93da0850.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_5f8925f929a5b26f3544ca31938aa75b3c59d34d.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_8004763f674dfb3f14b66dfdeb2a046e413ce2cb.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> fmha_ck_autogen_0878b9aa31429d23a93cd953cc6a2fc5f43d0d3a.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b5ba2e73df35f6e0f7317303823fde92a42b1a35.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_d34fcb56caa8f80404789fba0ffac447483a4d84.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> fmha_ck_autogen_cb1a0ce432c27f4cfa51731c3ef181bf60c8a727.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_efb9e7d9af47cdf79f15f674f8976c05f08b0ce8.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_357f7e626135cc9176a295f3d1f336a7c3852688.hip 
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_22c142d869ef940ca876c93033ad53b576ed34f2.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_1621507cf219fe608715d4e5bb6e5764022e2d61.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a25e2aed617e1ff31f93ae7e054313ee0dceee97.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_7ec038393ec329a894aee9bbac078a40f57a4684.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_15dc02ea7e0908cf0bd48034f5a49debfaa36219.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_758b211174da0f398b2a093e7389905b4f9c4060.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_548b347672451e8391388a400d016803f4c4cf8d.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_ae7899b1ef159ecbf01f27014601eb79b31b49b3.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_b04f14f829eff73afaa57a875f74ebd1e6860979.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_2ad492377add5c8f6d0d2dbf9ee9e4338bbd9f1f.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_7f6ccdb3c2d595fffd05bc5e6417b157276547fb.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_69cbe8eca7e3510f5caa7f13419cfbefbf031754.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_8bd7b8c63a51c8639b3cf27ad09d41ae47c480d3.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_f21596e8c608a795ff971aea8e199db9e72b65d7.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_1da23de9604b5d98fe02529075bad995954c12ca.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_49d4c005d723cdab9fbc307933c1257d114b539e.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_e2c9f955f227430c6224ebc347649386be7f01eb.hip 
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_290c484c2a366258941ee0051e139ea716a9de2f.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_84cca7528c7d1bf49ba79625733ff0ae7522c096.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_f3d0166931e4406873d8f552a5d5b61fde2391a3.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_8046f566fa7188c92568b277354e8b06ad382544.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_12d60c8abecb3bc9b84b0ea7851628ab17d8b0b3.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f50fa4ea674a590d0a817367ad9915a5fce20c51.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_0836d5dfc0f939ab9a4064b403339373caf35b56.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_de6683d175affaa5ff261ab8503f64172d8eba8b.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_beb9afccc15de7dfcb2e7d898abc0d61201de73e.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_e6e0ec1db1ea308e226f675e68e29b839e41b252.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7c3d8ef4da515960bf40eb1feb04d21950ad5ae5.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_fcbe827108d252b2f5847fa8e132c9c3e56a90a0.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_7993fc08ac5c6ce7a2eceb1227f4e3718dc4cf5f.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_06ae52ef937cc27c544e32025ea0dadb7fad982d.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_876a418fbe6183d0392b7a7d9986d067e323e2b9.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_b03ab68e33844f97aa58d463e00037bc11c50da0.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_8c7970957024de050748d3e31cef434f582d968b.hip 
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_add29e3e9828911a117dccaa5650e77805730d14.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_0e007c36231ccdae12f102eacca1f74b0711b9c6.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_765940baaaa2ae6ade43ef4c94a220eaa63702b0.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_c7af2bbfac25de2853be344b9f636226c1c0112d.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b2f91e937b427ecc932c0cb0c90b2c2378db0be6.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_8da8285bd6182355e3164cdc5a983375cdf0a61d.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_a3ff8445ba691807caadd9f26e7eb90851875280.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_9c4fc7cda4b560040cec93f63021b529aa1ee3fd.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_4018b1fcee808b6cccd131418b6ae9e8bf900d8f.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_88d52c5f70abb525b9c8aa8fc1cb3997c33ed67c.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_99e2f290b962f1617b0a9d4fd6d55c43e4439d6f.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_25938733446b6c0dcd159719f08d04a9aa467967.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_76f884e9ca116ee47b446efe9fc770c178a858d5.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_42e2326066c91452335eac05f25a6311376bd9e5.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_24643917fc970c043d1c80d8d4b17ec92deeb8a1.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_d937609afa8e21a761dad6b01ff3f26346e450fc.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_59beb9cb4e161f9dcff79080149076488d436301.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip 
-> fmha_ck_autogen_fd3558b4c7a667dbc365c4c2ceda646975408f51.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_dda8d021381083bc48b7fb1840729254dd8e5137.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_ed37ba962e0288e2840eb0925d016b5a7e3b3164.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_5467aea26852aa9a9e3dae76b906005ddf6fbae1.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_76be322fc072ca19baa82707e260c6eba936ae19.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_c921a4790f982d48bcaf950123c699647afb739b.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> fmha_ck_autogen_76674fc182dfa6329c73a354aa3adf458429444a.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_54402a22ceee3b665a3f24edb98b8398c35c6f5a.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ada016be2bd0e377fbe01fa7adb9bbb8febce100.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_6db86621d626722434f2ae9b7b8ab435a8dd8827.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_162b0dfbe3f615b1d164290799b2457437a0044b.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_628b28f65f19e7d1b22fb3b85b7cf3d09cd54ebc.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_031b12f9fd94e01aaff2c0da4f35f346822087e4.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_b9a742ceeb6736a2c8f9439d0b05e10d3e0c5c6f.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_afccf699f593c828e11efc053b144044e45b32d6.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fba36678d5047ded97ee7a7ba9feb9569afdb6ea.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_14baaaf1e90a075ab802c6e7d97c4b1605c8bd72.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_0237c76137df14fb808ade8bd6837045f2aaa5c9.hip 
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_c2a2856bf9a81544a30d535a13554e3a8107c476.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c2940fd05efd52bdf8a3f9aa4b78bde9b5809b34.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_d049a1b8f4c1c6d37973ce38593efda1de8ce0cd.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_f4b87f983a5e84582efa1663f84da76cf60b5f6f.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_4db2e63cfebcf84043f79be0321708cd159c62b9.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_f25b87c435bc5d7d85d738f3fdf68947d79f5a77.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_540bd57333c6839ccf5cf2e928edb996bc60c371.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_9583148fd684a7e6a312127e023798278415bd27.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_bf9cdf86a7944cd690b0fcbbaec235863acd10bb.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_2da2b905c4ce32234c2af62328adae6b1f9217a8.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c4015f0d0a7a5173810f6f17c00065e03fc61a89.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> fmha_ck_autogen_d773df9ccfc1ace90fe3afb5c00976deabedf6f8.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_d137b7b6e04e1caf43a62bd6788a75361cfa98f6.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_adaef10ff2c5d89530310bdf1d53a194f06a94ef.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_1be746990a2032f0363ad9f9112cc994983f4706.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_55bd9c4f1b7a0621c67f3e964d946ce22fb2fc80.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_4dc87b7d385e7b092e4706c464217b004fd8a6a4.hip 
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_91695dea4171747fb3cc6d910459f800608d07c1.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_c137c03bf161b2ec6a9a046fa49d7bbf80ae47b8.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_83080406598df6bd3102db70a554e496e29db96a.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_03a71615a088e972c998f9c7cb44566c268c5124.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_6214f820b39a8ba81e547a78ed19a909ac13221c.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3e2557f206fd81d82a3b9d59113105040beb891f.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_461737a13e24009bf1a5a4b780175043a9f2e33e.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_83f6a1837a65df12b7c55d25ca28cc939c2a6328.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_c59a22c6efd8bb8815887325aa0b739e260cc754.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6049c01db99fce654e9351e711b113cf7424550a.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_c9f28230817c9d9805c41dfcd4e834fe302e1df1.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_7728d5bec7941c9b6d5632bee8d67ed92b9c03ec.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_28f1ef32c4384ec26f3dc5e3af6a74fc8cebae92.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_594929c433b049a8cf949ff476309a8faf5c25fb.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_8441910c34830ad2459fb85c2c14af02da718fdc.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_230861e81e5acc523fa680534eed757b7b4a4e1d.hip +fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c112c01d201c366bdd7acccf2e1b18b00f671153.hip 
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_6b638314efcc4f16aa4a6e58e6caf2fda1711519.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_c8f6461673882d636772ae4d26e78eabcb568f31.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_f93bc23b8a4f1e0fc5c5756c4e1c835bf59dea09.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_4356b3a2ff49f72b91a6b9c215df285f2798ad47.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e1cc934ba7baab1a2eb062df1e4ee5066e9ffbc3.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_137fa6780d9e6bde10aec10a875c039fdbbc652e.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_06ba94794a14f0f0022af6f5f3c16e1e16959d4c.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_4b1eaca3c37a82d19f8dc91f06764170069ca3af.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_91c916e14198f6d18dc89915e379b01070434e91.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_8e816fcad5e9ecfca94a6491eb2274bcc41e558b.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_5fc66c5b53f83bf1e023e81e9d51f0285b3ae731.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2d9c659ba43bb907fd4e3e36a50958288bafd1a3.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_07ff04fcc273e469737512893ea3fb5876ac131d.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_22632f996eb63fbe4bc5748c5897b775087446a0.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_f5f1797f6b672a55476348571ce17645c8a62869.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_eee408cf9456ff977aa7d12345e9b2f1e60639f1.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_303b7b04496e4db7c1ba2436485dc7c8a4c88448.hip 
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_fcb0b08e29b2e1bf181fceceb9dc416e54f52b00.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_d06ba4c996570ddab77b6ff1e2a0101b638543eb.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_fc5ebf0f2200f37ccc0849e0c3745f6e2f00111d.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2caba3ab83239e474412fcf89fe0fbef97e51bf1.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_dc184767d723f4995791848cdc68bd948408204f.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_c53e295b68e807774ed31bb914e4bc59312a77d7.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_db0d0cf55d90b3f3c9eecada1db93c420f34b1ae.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d1c25cfc437d8bd803860e39a45b2f3b9fa48393.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_01ca79005067e20e4eed5a72ff9187cde702cd1c.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_a5e5cae764142683b70d3344cf07dd1edb7d69e2.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ca920c3239bb5796b1ab2fc75177eb3b820aa784.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_806f9ab9baf631df1d3a8d801e4cf93a102526cf.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_4b30f472f00bec9da0564ddc40e07112b5f9a117.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_dc039d422a57c159ea4dbcc867d766ff1b356a07.hip +fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_5b55946ff3c15a44b9c741e9f6bbbcb5bd4c8577.hip +fmha_bwd_dot_do_o_d128_bf16_batch_o2.hip -> fmha_ck_autogen_658552954505a2092662071401e135e84956c4c0.hip +fmha_bwd_dot_do_o_d128_bf16_batch_o2_pdv.hip -> fmha_ck_autogen_53bd60bd2afee49b30a583c32a45ae9f2076db08.hip +fmha_bwd_dot_do_o_d128_bf16_batch_o2_ps.hip -> fmha_ck_autogen_8e675919a6c7758cbbeecb83b7ac6c62f95cdb46.hip +fmha_bwd_dot_do_o_d128_bf16_batch_o2_psdv.hip -> fmha_ck_autogen_2d06f77a4054ca615d96636c0e2eba2a89850142.hip +fmha_bwd_dot_do_o_d128_bf16_group_o2_ps.hip -> 
fmha_ck_autogen_187963e1969301abfa61d06afc97faea2bb4efb1.hip +fmha_bwd_dot_do_o_d128_bf16_group_o2_psdv.hip -> fmha_ck_autogen_e7153f9a9b0b7c54ddf2debbe297efcffbb4fcfa.hip +fmha_bwd_dot_do_o_d128_fp16_batch_o2.hip -> fmha_ck_autogen_3c3b7e4b8c1efe59f79a15512716fce2282a79a7.hip +fmha_bwd_dot_do_o_d128_fp16_batch_o2_pdv.hip -> fmha_ck_autogen_19cd9f7b08cec83736605af63d9fcaf463a1aea4.hip +fmha_bwd_dot_do_o_d128_fp16_batch_o2_ps.hip -> fmha_ck_autogen_b4588379eaa268d79fe8f8e4457b009f204a5fb7.hip +fmha_bwd_dot_do_o_d128_fp16_batch_o2_psdv.hip -> fmha_ck_autogen_23c9b46da8774462de8c24e14b12df3ed596eb57.hip +fmha_bwd_dot_do_o_d128_fp16_group_o2_ps.hip -> fmha_ck_autogen_5b413bdc825ae863d53dab548f2145dc0de8fd37.hip +fmha_bwd_dot_do_o_d128_fp16_group_o2_psdv.hip -> fmha_ck_autogen_58a7ab44bbd9fbc97c7805860d5f6ac81d6ae468.hip +fmha_bwd_dot_do_o_d256_bf16_batch_o2.hip -> fmha_ck_autogen_50f887556a3540609649744957651ca667b91774.hip +fmha_bwd_dot_do_o_d256_bf16_batch_o2_pdv.hip -> fmha_ck_autogen_eac5952f46f4f2bf06257b00661774eeed48a323.hip +fmha_bwd_dot_do_o_d256_bf16_batch_o2_ps.hip -> fmha_ck_autogen_efaa0cb33c71cb8ca7b83dd0e7a6c7b01f6b50a9.hip +fmha_bwd_dot_do_o_d256_bf16_batch_o2_psdv.hip -> fmha_ck_autogen_71e5fb3544dafa9da03fd2de4bb9bd0718f6009f.hip +fmha_bwd_dot_do_o_d256_bf16_group_o2_ps.hip -> fmha_ck_autogen_3fad30ff0739ab5dede67a96e859f8c474c245f8.hip +fmha_bwd_dot_do_o_d256_bf16_group_o2_psdv.hip -> fmha_ck_autogen_4bef4d120e71bfcfe61d67aa44d24ceb907c2b9e.hip +fmha_bwd_dot_do_o_d256_fp16_batch_o2.hip -> fmha_ck_autogen_7d0f767c17385eb7d756cbe8ed444d7cef72dea5.hip +fmha_bwd_dot_do_o_d256_fp16_batch_o2_pdv.hip -> fmha_ck_autogen_4b68e4d00295b294320b94bc777d7d34609127e0.hip +fmha_bwd_dot_do_o_d256_fp16_batch_o2_ps.hip -> fmha_ck_autogen_33746071156e9ad46f403a539dc237e0a44122a7.hip +fmha_bwd_dot_do_o_d256_fp16_batch_o2_psdv.hip -> fmha_ck_autogen_3d45624dc6e33c477c73a155500b015b6c010de8.hip +fmha_bwd_dot_do_o_d256_fp16_group_o2_ps.hip -> fmha_ck_autogen_8250f27341241086515d833aa53ae873d4ece3fa.hip +fmha_bwd_dot_do_o_d256_fp16_group_o2_psdv.hip -> fmha_ck_autogen_8793dc3217e154b65ebba065aa10ab4dc2374ae8.hip +fmha_bwd_dot_do_o_d32_bf16_batch_o2.hip -> fmha_ck_autogen_1a11dd5ebb989503a1c182684e7f247e2f8cd9c2.hip +fmha_bwd_dot_do_o_d32_bf16_batch_o2_pdv.hip -> fmha_ck_autogen_e16075c3a5fcfe63ba12e854bb1fed6873f014ab.hip +fmha_bwd_dot_do_o_d32_bf16_batch_o2_ps.hip -> fmha_ck_autogen_937801fbb43fb6797f0425f08d13926b74d87c4a.hip +fmha_bwd_dot_do_o_d32_bf16_batch_o2_psdv.hip -> fmha_ck_autogen_fecffa403b3631b1957e1a9a06f18fdb3b4eee5f.hip +fmha_bwd_dot_do_o_d32_bf16_group_o2_ps.hip -> fmha_ck_autogen_5ba578c0e7abf1127dd0370f06d7278656c93ab9.hip +fmha_bwd_dot_do_o_d32_bf16_group_o2_psdv.hip -> fmha_ck_autogen_345a939a2491166dc520e9a2b9de7e43671e0c2b.hip +fmha_bwd_dot_do_o_d32_fp16_batch_o2.hip -> fmha_ck_autogen_7393267865f1c2b0aa1a09a586f54cec98eea4ae.hip +fmha_bwd_dot_do_o_d32_fp16_batch_o2_pdv.hip -> fmha_ck_autogen_93b885d6869400b0dc2ef1b2c2636ddfd21cde31.hip +fmha_bwd_dot_do_o_d32_fp16_batch_o2_ps.hip -> fmha_ck_autogen_38f8a89468cf9c8606cf12a930db062a83cd0ea0.hip +fmha_bwd_dot_do_o_d32_fp16_batch_o2_psdv.hip -> fmha_ck_autogen_f974b12e83e214c30995a25631d37df1478927af.hip +fmha_bwd_dot_do_o_d32_fp16_group_o2_ps.hip -> fmha_ck_autogen_2bb6da1095bd8669c0e48b5cd808cf0dcefa2674.hip +fmha_bwd_dot_do_o_d32_fp16_group_o2_psdv.hip -> fmha_ck_autogen_0e0a2370f2a320484d8f9f21e3197425c2dbe9ad.hip +fmha_bwd_dot_do_o_d64_bf16_batch_o2.hip -> fmha_ck_autogen_a9f00f270680de81df7737e848e0408cb070e68b.hip 
+fmha_bwd_dot_do_o_d64_bf16_batch_o2_pdv.hip -> fmha_ck_autogen_61220f6dca850a5b5ccf1f619a267c40c37efeca.hip +fmha_bwd_dot_do_o_d64_bf16_batch_o2_ps.hip -> fmha_ck_autogen_b192c55f002d8540d5f965cc4df0c2e33f4b9ff9.hip +fmha_bwd_dot_do_o_d64_bf16_batch_o2_psdv.hip -> fmha_ck_autogen_295a523f815eb822d66162d4feb75fe0bc50b648.hip +fmha_bwd_dot_do_o_d64_bf16_group_o2_ps.hip -> fmha_ck_autogen_292b4f995d622826af5d1f2bffa7ba68467c841a.hip +fmha_bwd_dot_do_o_d64_bf16_group_o2_psdv.hip -> fmha_ck_autogen_5e840be0741afa4d41fd4789c8300223fdc63ddc.hip +fmha_bwd_dot_do_o_d64_fp16_batch_o2.hip -> fmha_ck_autogen_0e1dbc9c433ce8ec33ace9e62550261d613db582.hip +fmha_bwd_dot_do_o_d64_fp16_batch_o2_pdv.hip -> fmha_ck_autogen_6eebd0c2fbfc85f938b10535855c388971129a28.hip +fmha_bwd_dot_do_o_d64_fp16_batch_o2_ps.hip -> fmha_ck_autogen_0bc7910aac798f0555e9e505ad7f177c9fbbd92c.hip +fmha_bwd_dot_do_o_d64_fp16_batch_o2_psdv.hip -> fmha_ck_autogen_18b92b4e249195ac3e0c74d246585a4c9e0992fd.hip +fmha_bwd_dot_do_o_d64_fp16_group_o2_ps.hip -> fmha_ck_autogen_278639d44a4a8372a627a7c31e9527c8faa26f97.hip +fmha_bwd_dot_do_o_d64_fp16_group_o2_psdv.hip -> fmha_ck_autogen_8e938d0e3ad30db201880642e57758285b2ec4cb.hip +fmha_fwd_api.hip -> fmha_ck_autogen_1ca3f45d0be2d1119cccd0af042a3e8adeda2ed7.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv.hip -> fmha_ck_autogen_f727911254904ce4341e4ff5f8bafc430b8cfbbf.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi.hip -> fmha_ck_autogen_54208a6e8c5263e38f9ffcb062564ab61d2785ff.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_dropout.hip -> fmha_ck_autogen_1d3ef3d5ded0dfe2a0bafb52ea8f841658db35fd.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse.hip -> fmha_ck_autogen_f15c41ddb04ec7f80235bb3db19198dd6b699713.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse_dropout.hip -> fmha_ck_autogen_a5c4dc0d70c547dbbfb661e879ba7f9adfafc2ea.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_d7290cc4c3036c9205e689cbcc60e7d16b97a7d6.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_dropout.hip -> fmha_ck_autogen_0b2647b5982405a48e8c8888552a4b89386ccdd9.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse.hip -> fmha_ck_autogen_eb278488b2cca114adca5e4614d86f92447f937a.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_29fe68ba10b3480dddc9866c51ca8b5efe962cc3.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_dropout.hip -> fmha_ck_autogen_92992be6252f2afdc368bd4baec4b8a55ae0abf8.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse.hip -> fmha_ck_autogen_501dcf3213efd214cc2ce8c9ba0027f991d241b4.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse_dropout.hip -> fmha_ck_autogen_aa6d13b09f85ee62bb5018608812181fb43afc86.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask.hip -> fmha_ck_autogen_d0f63cafbeb445408c884727b473667fb479675e.hip 
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_dropout.hip -> fmha_ck_autogen_7596c14b8fee751d03f42ca48ea4f66e87fc2e2f.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse.hip -> fmha_ck_autogen_c2b719893a4d8a1e71857966d399f06c0a41749c.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse_dropout.hip -> fmha_ck_autogen_071751b1012b90f7b57f8591cd06ae1fd27d9cd3.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_d00f65bc99ca08eba66564d34f72f2769bff9491.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_2273457ac3be01cc1595a015a5f598f8290c77e4.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_63c411351ec59bdbed2590c599f9eddf7807b371.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_042a156e9eb935555ab14a84461959b466c2fb5b.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_eab6cdc59bf216f7045f0cf5f221bb91ec415cd2.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_d703eea8075cacec4d41fee7dc4734f593ee79e8.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_2f32f2d658f1f69840fbad511ce8a3851c859d52.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_e6973d75297bd2c3432a7c88e8a9ee1c9ae693bf.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_854c8003a508ed3f8cbe6967c4ae2635a491c721.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_ceb9544e2a0caae2c9e3dd8bbd2c509e8dca1379.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_e83c604d1b8260958becd1c7c209745ff9151715.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_3b26eafe76cca8e74e819220b6de1f4279d48e43.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_d5e82799f4452e148c3e02acd6526cf30757eb52.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_5435b4651a90e331fcdcf224282457e3dc038a30.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_1573e3d855d28c54af612ab950b081302891d56d.hip +fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_4e47f8fa40332c6ed12d9971e0b539049a871c34.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_b285e2f1970b78e18002464eeda63798229bbc3a.hip 
+fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_75f21e38ad01fade35b1db40adabd75eb602410c.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_81f6c575c3fa2ccc7e65022f1ba65c8cfc16541e.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_45b9871c220c0065d74bffeed4021d0304a9625c.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_f028af9e5e3c25800dde938e991aaab4fc1d64aa.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_7fa76fc1b066a15b08dc6c24a7cf33a58b4cb6cb.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_157768cd725813f8111d265cfdfea7f42034e5e9.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_541874a7633e5713720b9d084b6d1c6715a51a17.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_6f88527a2cdb5adf51407f4661a254bb32d7de23.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_a55b47aafc4340e69e300ac61a7601a5c14513b7.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_20d5c3c86398f6ce55abc90db3e362dbf9f457f2.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_8cf1007430da272174d3476d042f398627e83512.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_e7d37e7ee96c392fa24c02a9143438a3a7d05741.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_dc91797c1474a368e9cb056b50b4629d7736c3cb.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_3cce3baac1e3ca03af0c3f4ee4d0158ad1031e9f.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_9d6759d8855c4c6289f1f241a1628cf0406c1b64.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv.hip -> fmha_ck_autogen_b38a1d3cffae01332a3a9d9472ff1b2c443e82af.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi.hip -> fmha_ck_autogen_2cf351fc2c2da4a8e1760a3affc9a5947c6b3bda.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_dropout.hip -> fmha_ck_autogen_bafbef3f13d429ec3e9f4672218998d5669d79f2.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse.hip -> fmha_ck_autogen_3f34433b784d1e405ade3378918641372a30bf6b.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse_dropout.hip -> fmha_ck_autogen_5fb062527121e627871b3f1b2a94b96c42e51205.hip 
+fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_5732094f5917e9164ee0f973ac6ec47245a69101.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_dropout.hip -> fmha_ck_autogen_688aaa193f332ed13e017e78ec07a7c80e45f6c5.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse.hip -> fmha_ck_autogen_1cbf88db44aa5f884438288a325270d29c7a04b6.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_2660282ad39ef034fecbdb74acedfb48620b7dfd.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_dropout.hip -> fmha_ck_autogen_a59423c095db052603d77073d409534bceef425f.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse.hip -> fmha_ck_autogen_3fcc6893456a559c7d22714116022fc69b372266.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse_dropout.hip -> fmha_ck_autogen_c7568e11e44ce70924d27e683190422cfae5c31d.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask.hip -> fmha_ck_autogen_f79def2b4edf6d18f6ef1d6b141f9e0435441f6a.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_dropout.hip -> fmha_ck_autogen_32652a27e8605cef59c8341813b68e7513be23c5.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse.hip -> fmha_ck_autogen_b20e314642cf565e4f32bceffdb5c0e653ab627b.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse_dropout.hip -> fmha_ck_autogen_a74b0e7dd816ad08eec5a1bba6e227afee9813ec.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_a968df29f5ae1463706b7981b3bde55918e1aa65.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_b5248f443a12d96815c04409a00102923c717023.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_291a8bdf9d63b112e7fe5fa7e8835a6789cb8ecf.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_6d5aad18f59e47a3fa3278c7ef1a6372830c33d5.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_c063318cb851ccaa923be12d34c84d839bc64bb8.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_a5a7833f4597bb03a3e845d5580d677e97421040.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_2d9a04b7f41dd6f0db017157a44790f35c626e2d.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_98f5efcd500ce6b9ffc14bc9877e0ba457539925.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_135ea67de101135ed5fe04f5cab1ec1d7b3714bb.hip 
+fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_951343832a5bfd060c8d12da0d8a090f070a717d.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_f24d42e820adc1a26a428d59df7ffdd7f8580176.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_4dbdd9c3f496a27bde68cf86374999ff2dd53505.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_3be7cea6df8e6dd56194e1172f28943667f1c4ef.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_483eaea4096c8f5bee16a64860432f0634a253d8.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_103186dbad604763008e0204a1ea90baecef8877.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_367e58867c46d96c9bbaa96eaaa9f93595c9e099.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_311104394c8bef8d4ecff35c1409221e723a5a8a.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_976cf509d9c2bf86ba6ee5ded544fa8e6717f590.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_24410fd9a4150c33186a2a365d06d8f6ea621c20.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_b493c99888d82cd2852bfb101f99a2e6a27665b8.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_1fda1c96568eab89a8f6498f8bb23c1223cdc7b0.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_053981d9e7af2ebc0f91e61ac5e25cbe68c95bd8.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_3110540b50e95e99a5cccebe47d9d3a83093c2fb.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_1fcdcb750f382fc7828a9886585f50efbe5be735.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_c3d0eaf9399c863d672e8c08d123739bab837d4b.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_9d69d441f48f9ea346dd8e00376a9a708da3ad87.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_3992d5df4ba2e999caf6889a852db4e1ba078e65.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_f30316cfe49323638f71ba688dd8ff9b2266b335.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_797750ac0b18b48f56ceb4640256e9bd3a36621a.hip 
+fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_942439e4f5644a3a4630481bc7d98834b29b6e1c.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_fac99c3c82b77946f6844699d2333cd532a78a26.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_98f9a4f4d85f292b78123599a2e1798f12aa545b.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr.hip -> fmha_ck_autogen_ea591185b1c5f521023e250a26f742984255b241.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi.hip -> fmha_ck_autogen_48300e0aeabe337785d4c7b41796ce65df6cc42a.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_dropout.hip -> fmha_ck_autogen_e514c6b4bc75d95a150104a17972abae77cb47ed.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_lse.hip -> fmha_ck_autogen_a64b4cf3f6706e4b4e0af4402e2263b9a1585f9b.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_lse_dropout.hip -> fmha_ck_autogen_e389d0e4442cd8304081892ddc75043e68a6398c.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask.hip -> fmha_ck_autogen_ab43f4a56c166dad0113f51b337a083f4df7cdb6.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask_dropout.hip -> fmha_ck_autogen_d4645b713821371161a9925dec8a3d6c157ba1aa.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask_lse.hip -> fmha_ck_autogen_0b90a0186d8b8004e3f19886c7992c8e04d0e066.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_d34d6cdcd81a456125ab5e0875466c6334d8e5c8.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_dropout.hip -> fmha_ck_autogen_d0b09e8513646fbb2a007544a63ec9e2b04dc4c2.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_lse.hip -> fmha_ck_autogen_ca3d98ff43fbb80ceb82fc22ab039bee898969b0.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_lse_dropout.hip -> fmha_ck_autogen_7ea9c37d92e344f3cc58cd4d1d00f19167e3623e.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask.hip -> fmha_ck_autogen_db85839ee8d464c5a81b8dad9839f5e0f4b467a8.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask_dropout.hip -> fmha_ck_autogen_32527660fa7aeb9a951a9f2fc3c53989bd141c48.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask_lse.hip -> fmha_ck_autogen_528db08068589c6e4c096054d26a2e5be63285b6.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask_lse_dropout.hip -> fmha_ck_autogen_d600779c17b7b21c18e1308e6d765fe02a7945d3.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv.hip -> fmha_ck_autogen_445e28a8a51cd435130ded2abc9fc606e522c713.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi.hip -> fmha_ck_autogen_8a980749c6b2a18c80426dd189e5506334343ca4.hip 
+fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_da822ea727fb3543e445e4000f7e6ebb946d6a3b.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_f525b59df454ccf53da6cb201e0aa8d09f52a2ad.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_0a2b116fd5065109aae46ee547e4f49ad0e9d6e1.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_366662dccf2f650bcd8123c49006c759cd4c0ef6.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_816c48e129a0235cb3a19124ddb28cce286fb368.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_356f83cb96d0313abcdb24955edd4264df72aed7.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_0e661b5f30566d1f159f060c264849c7ae4772f1.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_dropout.hip -> fmha_ck_autogen_61a9e92183ba87924e73ff0b5e25bd12d6038e69.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse.hip -> fmha_ck_autogen_e502730dea6987e2c038446c448aa08bdcc23113.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_f851da732f397624717160f89271514bc334b59b.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask.hip -> fmha_ck_autogen_fd345632e0cae0d549ba79626a08b1885711deb6.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_937c48d0b7096ad6c8bc445f13f2c8c1934695ab.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_a2482a64659c838f3da55f56e3cbbee1dbfe6722.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_f34fdb8294257d951dcc9c4fa7ecf1192568b91b.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv.hip -> fmha_ck_autogen_0aafb881e34a3794970a1282af740b3f19c138b1.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi.hip -> fmha_ck_autogen_c250ea59ab6e1ee39cce15cbd3f181047cdee31a.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_4ce671f5defd76ca08614a7a1f184c36c0f1e2ab.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_b9627f9c8d0088df0364a64643f2b5dcd951f2bb.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_a6461d72fb6ba50e81de3f661528c96dcfdc3f3c.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_aa82d20635e592edbf00439294835f6f39ad54a3.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_dropout.hip -> 
fmha_ck_autogen_146eb8c40e3146e06936f3141b2c4d92a578ddec.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_c28de8f96c8315877031a2d56261e95fee6aef44.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_39422621a00ff79b2f5ec0dafb957c77693537b3.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_dropout.hip -> fmha_ck_autogen_a0a556c9358ddd6db719458c81d2d6d822a895da.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse.hip -> fmha_ck_autogen_c2fcced07cc194a8050bc7b2f791453b3f5b2064.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_210ef512b7862837f54acbc3b21e135a192647a3.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask.hip -> fmha_ck_autogen_bef3bd014a918feddadc98eed92a7734f9bcd890.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_ae1ab1f4bbe86bb9bbc22e4774648076c321136f.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_52a8a323414448c50571a334f29bc0a38919b61d.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_204a573ce6b7d2f90aede543939315561cc43177.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr.hip -> fmha_ck_autogen_d8901a63986cc28ef24cab012b32114851a8c1ec.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi.hip -> fmha_ck_autogen_12d5c8a4988efe60ef7943ecd73e18a28a736583.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_dropout.hip -> fmha_ck_autogen_e5b65fc519ea7cfcd19f7eddbc3acad6842ff558.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_lse.hip -> fmha_ck_autogen_743176ecb1f0bc800c870861585edf56f88d7739.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_lse_dropout.hip -> fmha_ck_autogen_6b0ef67ce0f178aa2863c4909f5bdd7f766c9b2f.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask.hip -> fmha_ck_autogen_ef40f0acf1885096efb840ec5600ec421c4db331.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask_dropout.hip -> fmha_ck_autogen_523e5bf45ec5008aa3aba4773e68a78e122b2fe7.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask_lse.hip -> fmha_ck_autogen_55cda610c235987e13232e828f8d86fa88030560.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_566b4782793c6526bfce7362efbf6bf069928b2b.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_dropout.hip -> fmha_ck_autogen_cfec97bdfb6fa95e057eaf5a8138853e1c0884f2.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_lse.hip -> fmha_ck_autogen_6905ba47078abd7a5b6a51eb93b26095517e7f70.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_lse_dropout.hip -> fmha_ck_autogen_8840e8899b4e632714632450bcef001c6070f955.hip 
+fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask.hip -> fmha_ck_autogen_d867098db97b3f26e71a151c63b74260bfab21f8.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask_dropout.hip -> fmha_ck_autogen_bc238fd2095b26a167b41cdec8280182330b7b25.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask_lse.hip -> fmha_ck_autogen_b737410b404a51043fc3bd503c0b107c297e4c9f.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask_lse_dropout.hip -> fmha_ck_autogen_b4a5715b550f67b8870ba66e1e6282a26cc1dbf3.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv.hip -> fmha_ck_autogen_12207f4b6e7fac27d6c16493a5373f448a2aaae8.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi.hip -> fmha_ck_autogen_7d5667b27f15a06d4040354fba3601d48bb9c045.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_2695783ae8f0034692efd6563f789ef03fd0f4f3.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_60801d21c14796c08377349ec86a6c800af497b7.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_159ee1f1b44d1a8fbaead65d8449413bb616d15e.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_9f0517550c7a23882b95de451e8099ea2186b4ce.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_80f51f0e178c33e6196df1d2e47bd38bf5391cc8.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_489e7be0f85656d012a6451b65f6c1d2613b187d.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_e7de729aa50c10d8101ef504138c3769e3286753.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_dropout.hip -> fmha_ck_autogen_25b3225da1e1842f83592971a1f62a0fe30aa9d3.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse.hip -> fmha_ck_autogen_ce4714e4f33340859c106a3129993e22652262e2.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_bc4e0f0496a34d2fb43c80ce0162ad4183f29064.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask.hip -> fmha_ck_autogen_a9d2be18e2d53a5144f97dfdebb225fcb6d611d3.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_4ab5d6e8fbfd92e9f7e47bda5cfbb0d4162a6319.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_ac1ccde31b47e0e56ee0daab6403fed7895208c7.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_5cd03e29403ad53d6d52e5e81182ea6ff5aff2be.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv.hip -> fmha_ck_autogen_2005aca3520b171bb82d10ad70fef44f28c19776.hip 
+fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi.hip -> fmha_ck_autogen_c402e84359b2037a29efd1d6ce7213ba7605ab25.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_95061acc6650fc7b79fa1fe5b2b1e083555eec2c.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_1fd9fa7c2e13d0bad5fddb2b5a316bbc09d397ea.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_dd9494d9ac35eba6794a4f9120d2db9932596ef8.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_82d7f61e6313930f063758b61102e7a43b118beb.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_2b50073f6dfeb7ea77d5dce288a1d2f08f8f6362.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_5fa7fafd4227918e0c7f0c6ca3b2bd673cd07279.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_522a2a9435103ed405dc1500d31652f1d431a49d.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_dropout.hip -> fmha_ck_autogen_4b7393d55600c9892558248f4131fc06a6cf3309.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse.hip -> fmha_ck_autogen_d66c30148a6fa816937f2f095802264d3dfa0273.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_8f7166d4bb0c1c9b9999ba16a1adbf09ebfdb6f1.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask.hip -> fmha_ck_autogen_80cf0997573f4bcfbaaf75e40f519580a7495a17.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_48d7d145f96aa8958a9208d0c8887742a8c834fd.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_bb111b7acc269f8d5e70915d3efde4c425aa5f5c.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_48435e5dd23e49e19dd313f9891ffec800ce74c2.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv.hip -> fmha_ck_autogen_e2762543d3380185e304f84749a70db1b8d3dd8c.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi.hip -> fmha_ck_autogen_5093976cb7b32a8bd28ce92fc13af00a3e21f737.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_dropout.hip -> fmha_ck_autogen_efc6a7b25710f0626c3af534111b161e1459d2e1.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse.hip -> fmha_ck_autogen_a8a744edfa3a19d1493611df5bd0d4d59b707d43.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse_dropout.hip -> fmha_ck_autogen_e95e3908479965856843317c8b0c42a6961dfd23.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask.hip -> 
fmha_ck_autogen_2b5317b6cde327a842170ebff20c2b03d81379ff.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_dropout.hip -> fmha_ck_autogen_99ae680eed89ea93a3a94586bd5a68dbc5439f37.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse.hip -> fmha_ck_autogen_1edaf9d4270d2ac61c299320e06ba73f44730364.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_0a4e76d89b175e1d9fd2e9fb908d5fce1ebb945d.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_dropout.hip -> fmha_ck_autogen_fba47fa8d9b5375bc408af68b67345ab9dba2eb8.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_lse.hip -> fmha_ck_autogen_830e3532f27b391585d5de90f3bdf97992b67651.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_lse_dropout.hip -> fmha_ck_autogen_66a020f728df204ff51e37d2ddc21afb0aad5e7b.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask.hip -> fmha_ck_autogen_07c3fc96d2bebe546dce6ebf46e5c7a519959599.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask_dropout.hip -> fmha_ck_autogen_74d5f2aef029f2103bb419cc982cae99fd1a9253.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask_lse.hip -> fmha_ck_autogen_58a784fb478ff5b3f1e2da9765a3a777efda92e3.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask_lse_dropout.hip -> fmha_ck_autogen_0766e7aa4b263a811408b285213e47176ee2bdaf.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_bbe23201fbebed25781f249e5c77c31e0e7f9ddb.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_7a890b126da2d8cfbf84f048b779cac2dd56b509.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_58679919fcd292a2a69543de0db94e2985c9d364.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_84fc5e94f89d6a9287cf64662a372784511468dd.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_3bed3aaf24c73073c604a3b23bb4b0358b8e3490.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_dc5ba6d73f331c76e696953606c5b347b6a46f3f.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_b4f12f10d7b968e0d8e7c23f36d3a360de74a905.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_41b68458076e6cb129d3ec793e95b91430a0c8a1.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_56ffe9e21362afe9c3a407c09d5de186954931a6.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_2ba934408c75da5479cc41f96b98ea7d333635ea.hip 
+fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_bd6aa39d0ae3c87d011610cdb5e2e317f337c454.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_aece14f7a220222eb4ce6783ec2b9fce6fde94b8.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_6e240106c771ebea461fc2a87b6da68e510aba70.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_232f61bf31dbb5de5d7039d5ff2338068a759b68.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_e0e48d7edfe9513f24ad9fae68cac3aa940b17dd.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_bc897852a4ca992961843144f4ec4f8b86dd5e9d.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_f1246d1013d954a9316f4432c986d3be9459c548.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_6a4b6226b355bf35d4d07aaef1828091f03ad2ec.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_2b49a9b0801a06dd89c7f7182d7590b515df1592.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_50e7b11019fc2299d70869253877319b03388244.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_7f9bb3486fee7b7c9e24300b8a4e4ce88a11bfc0.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_6785dcec0197fdbb50124ab06efa627f1a2c0567.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_f87991cb7787a29d3ce4711b4ce04c5fb6a14ca9.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_bc4425e30a0b17e8b31726817e8d3177b5c51934.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_54940ce53998becf9bddf56df7d19894a7658168.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_ebb241b947a0adfc8e50c5d71765c14af24593ae.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_3d3f3eb2f5eb1f3287879604892b1c230df85f1d.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_7b9a3bf1a9b37e0bd9bae6249609e5994dc0dba1.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_14221590b90c48d3cf259fb4e834ccfaf7f3209b.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_445cd8fa559588f4264ce6192f2de3e3065365ea.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> 
fmha_ck_autogen_7a902ed4ae3cc6558c73b730ff3949778007a230.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_0682150e93f547e00f13cd8984779bf49b91e50c.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv.hip -> fmha_ck_autogen_d86e4dcbe9c4cac8f7c8c5d97ce384ae0cbdbfbc.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi.hip -> fmha_ck_autogen_1df893ee660d37fba7eaca452ae65b3e45a73087.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_dropout.hip -> fmha_ck_autogen_92739f4464512feee083b875e11e11eee4f5b448.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse.hip -> fmha_ck_autogen_65910c8b7a30acc731948ab58467fdbe4fe32f6d.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse_dropout.hip -> fmha_ck_autogen_df5b1c6758d4b8540158299dd0362297083084c2.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_ec7fc24902b1ebd8f2bf8088b0ecf6de8be8362d.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_dropout.hip -> fmha_ck_autogen_9e51083e13aa4dfa8c969f8f916835a8e5e9ca39.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse.hip -> fmha_ck_autogen_b41ea5293bc1c56efa2c4b5681d965aa6f2ce6c3.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_813e60e8405aca3f7fbed19452ae37574ada9a77.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_dropout.hip -> fmha_ck_autogen_0ebacd06455ab20eba78b389462946716b5819f6.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_lse.hip -> fmha_ck_autogen_15b255dde1a9d915e582ee2a83de7d83190c6a24.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_lse_dropout.hip -> fmha_ck_autogen_7b2d3680c3578c7292349b58843aef7a82e0087d.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask.hip -> fmha_ck_autogen_1d21263e16dafe79b9fe2f998847296e575c14e7.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask_dropout.hip -> fmha_ck_autogen_2d23a26e0a59a8323dd97632e610d24624143fbe.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask_lse.hip -> fmha_ck_autogen_4fa883a36a76edb276a66c5d779294f170d6d4b7.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask_lse_dropout.hip -> fmha_ck_autogen_9207a63fc55c411c73e4f93306c5ffed800dd249.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_0a68c2f9a3acdd787b81be455cbc7836c8bfd90c.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_4217a48a1677bd26cd48e512f1fc8830a8a551b8.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_2b0bcb241e5a1be1d35366461408d06e095a26ef.hip 
+fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_f3193ea266f3718398bc5622f8bc7042c3527a42.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_bb28a4e95723e3df380f98b5ac107c4df353850b.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_61204f6805d5d830aa6fca2a9b5f238ed63c3a73.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_6649f19deeaea20663bee781af7edced7f7a4fc0.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_d3784fb4c0685d7b651f4113f3c71e050881f3a5.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_ed6bdf67720e938d538a867548ac3579b8238169.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_971a08c2e48d805b295d979b24173a04cf58def0.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_c4997f79435cf64add10506acb97d0647cfbb3d4.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_188a70d526394e254274df95de0727850820326c.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_661b49505cfecbe4ec3e5c7371de3aaaa85ac9d5.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_d63c8c746055851217a514321cd735eaf6937263.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_745705ae121a1a331527cedfe4d31218a428a0df.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_6fa6478cc27e52fd9511fbff38369c921155cfb9.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_4fa4d21931b9afcbd70b1567995d3eeb6f9308aa.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_d43715cce8935439f90172d141050d78c7e76fb7.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_ae1afeb6cfdf860ff08e4c2f11c922fd5bfa621a.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_f24bd5b92ce6bba640b8ec6b4e53fe35902c5572.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_481415463f0316ebe25ff2fda47c68cc54db3359.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_db5016bff9e5dc37184d2b9417eb351c7ea1c322.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_d64b8b52f4a98801e185e2f132b2f80c29dd0c37.hip 
+fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_09ecb6347009f6a5d5530a6acf90f9f40288cbcf.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_50e59bd079f4d205b613056f975fd2b4e372ab10.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_fd10a3b937e9659716925e39a01d794914b08e26.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_ec51d24ab5f24e003ed6751ae8ae5b327892b15a.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_a5f8b7b2a891aa9f2ab49762eb31d835efdf18b6.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_9a0a70932bd587759df1e5e150b25b0126d7b529.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_9d3d274058bc0a3d4d35d90669587761fdfbdba1.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_23914c00690ac5c4f89cdbbaf00732ba66c5c0ef.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_0befed50a89d80c22b2c8c3d5ba67d73c3d0190e.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv.hip -> fmha_ck_autogen_88c04463f9c5ce565a9daa8c22e16de80fadd707.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi.hip -> fmha_ck_autogen_01e8aedb7b7d77f44a46b2e9b7a826f245aaf4a7.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_dropout.hip -> fmha_ck_autogen_beae876d6da465687f162136231f15767cc7bb14.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse.hip -> fmha_ck_autogen_26f90358e522d7bb7c76c3a2c6010f0f38788bb6.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse_dropout.hip -> fmha_ck_autogen_d7bda8157fb27d544e049fd7d2ec735725f1bf44.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_9fb389d4b5ba590baa951f17da06f0e53d2bfa55.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_dropout.hip -> fmha_ck_autogen_428ce4e14cf94b284ffa735fe03d923cc74c9fe0.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse.hip -> fmha_ck_autogen_900d7f81c73b35ea64095d01c5d48d9190839e0a.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_d2daccc4b3a0f90bff39cb4597f8b7e484613d9e.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_dropout.hip -> fmha_ck_autogen_f280e1639680ac1e5830a21f921bfe2cf364ef42.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse.hip -> fmha_ck_autogen_0dde401aa76cb5425563cbbdb0362748148da3ca.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse_dropout.hip -> 
fmha_ck_autogen_dc62a8db637d32e7dfdb2521cbdae6e1fbbd5fd1.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask.hip -> fmha_ck_autogen_4cd3de43cc1f7588d62a10362f59d113ee818846.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_dropout.hip -> fmha_ck_autogen_224f9af5e5ca519b21b71a54acb49f50b4999c47.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse.hip -> fmha_ck_autogen_4c8720923c3452e3aebd7b9c1b4b23f0c35d7e4f.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse_dropout.hip -> fmha_ck_autogen_2c7aede7762a524a7a424cc4dc46e43fdedf73a2.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_a98925d99dc484da41dd55700e151cf545cf821d.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_1c65ba6dba01da9caa84ba89453b61d81376763f.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_4b76e5dce9af523422782dd25d8dcf6f25edc68f.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_fe245e9ea974adce2b9807d33b9ba12d916eaffb.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_281d897ad17d7f6db2741b396e6b85a9b8f35286.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_31a968898f0bc6366313e41eddb5e3a3ed12dc98.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_52688999141a72e61322140db29043ef9f7fbc3d.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_92b722cdabcfaa388ccc6ccceb7e42462f3bdcd1.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_47f3ced9b5ddb0dfee8ed5e7df8eca0bbe273047.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_d2dfdb42c1b380e860aa5609302f29698dd27923.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_7fe409f4421193fb48a54aa5f26bd6229d23204c.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_b3a104733f678193068d8642d6560faa03897258.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_df66feebc9a0dcc508ce002c255154622875e524.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_8fa4c40e244b412a07933d369704bcdaa6d5e74c.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_60efa9c427dc278c0d1bc31189f683cd45e4d873.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_8e50ea8dd480012cbe10be392cd26d1870e6ef9b.hip 
+fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_e5ccd5f7ddc894b2717112cbfc766804e02b7bd1.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_4911bdd71351610d55916d452495e599960d0a41.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_d2d08c5470a385d0160b2c1441fd1c30fff1c17c.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_012c0f480917c329f4c3c6c666cf32af2d82b294.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_0bb81407c8a2b3cdc5fecf655b3ad64d5d729cc9.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_7ff65c7abd9b0d8a2df9302d6dc167637b3a72f0.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_d712f23ef88ae5d7b161d36f42d22a5ba53b6354.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_5bc803342862aa30e23e5be7d84e611bc571c529.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_0ace6e29e1d3060c3086c08fe27b471e375f9c75.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_54ff49018f1c12b9fa31e523ad40b9cc162ba34d.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_351425a006aeeff4d69c8570cb6bf1e1427d2c21.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_fcb6ef39c3db49f26f736d6c9221dd825409ec4e.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_f98a6b193fec3203eaa75819f6b51aa45a48f212.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_2d446754d7000673779d15d3e73039fd3c10a720.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_ca00cfdc5592b7440d72482a18781e9cf3afb05a.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_1211733062ed30b876f1d63bffa642d77e258dd6.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv.hip -> fmha_ck_autogen_9b6d08e63b9a90f2524cbfa8c5fcf8b82a1d2d36.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi.hip -> fmha_ck_autogen_e52e3053f30f780f346fa6b7a836ad2554cb85df.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_dropout.hip -> fmha_ck_autogen_3ecf565a5a1c4a09887c67ac3b9a019dca427ac0.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse.hip -> fmha_ck_autogen_52a89981a05963efcea7ba5c1e967638beeebbbb.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse_dropout.hip -> 
fmha_ck_autogen_2173b7c710d418f44dc2b41bec5905024334eae5.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_b1ad101ce91348266d3885afdf2996a0fdb72135.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_dropout.hip -> fmha_ck_autogen_4da9e9b7277bc90518ab92860bef2097ba96d982.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse.hip -> fmha_ck_autogen_7e1bdde812c332c9fc58613698568a04771b9fa8.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_1acf2f892742b1d236d2b31a8185c6869126adad.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_dropout.hip -> fmha_ck_autogen_155bafb551768855c8c01faa63e44764ebe6c110.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse.hip -> fmha_ck_autogen_f053c9c32518b895daaa3521827f37af78836fb8.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse_dropout.hip -> fmha_ck_autogen_adf160741a4f751d2f15d6eb23d4121cdca62b55.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask.hip -> fmha_ck_autogen_34c2db98d8e2e690f499f41cfd5afb831b756f54.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_dropout.hip -> fmha_ck_autogen_0789852b0cd3cc030c78b28f2fd5b6b0546382a4.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse.hip -> fmha_ck_autogen_532a6ffd8a21d3e98342fd401f0247f62ca4e038.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse_dropout.hip -> fmha_ck_autogen_d0daa59f5dce6fc3965193ae37d8c82a3d1834e6.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_4a4a00bd6ea27ff20a2903d619e1361b5e27672a.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_93054acb8a9508fd0f0f486367fb62454de47c39.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_b774450ebadaacf23e944aaf8ca90eada01e8a5a.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_2a833fc01e88bd8e256ef64ae8251dd0ed10720b.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_aa522b43c5e5ea69bcabb4c0fe28def2bd081a12.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_67fb736c61088b8dd92fe0371f5c98e23bf9077f.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_b5c3131fb8e5a25bd4a14bc9075eb6fa01b61d02.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_d7fae2c18645d36a181a0bdd2d8ca7a4ac0f6d1d.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_c355189ade9b1a8269230232db754a3881b53168.hip 
+fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_e035773419a9b3631698a3d375d829af55f7731e.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_d992eab7de49033f5480c5e86a69e675db0d2a19.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_5382a30dcf702daae19bd6705864bfe36e09502c.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_167f5328b035ed59a6f05dfee31edd704c4b07ee.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_c1b94e19d762ddc33cc4e94c6675d93cbde21e3d.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_606f5e0b99814b0a82a731de36f28024bc317801.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_0ad9d68fcee021437e13ffdf94d78252205f5a31.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_85156f2c556c6ef6180608c361b7b35ede71ffea.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_890aa875ac13957f00b30210477924697abf0c9e.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_3108502fd29d3a24b32177bcea968121ee809115.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_d66b79c4ebdcfd239cecec58203606bc123bd6bb.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_5efe77ca5c394a60af0313072cdd132216a52bf3.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_772016803aa3ca6ebe785557118365f9be7c4339.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_93728d999ae43ee1b5a16e60b90cf8533c7d303f.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_a1cba1509c413c870c5d784410855ee1bd737da2.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_c59ab718fa23f24f09a713ac28a339208a7a5802.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_afcafd07c1f56e74373ccf37db35976023456d50.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_ebb9abf5b09e63cbe76390bb46ff7cbefb3141f0.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_419461cdb5687ebbb7bf0be136071d70420c1619.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_4beca56234ff6fb4f23b9b24822887fd9a3d0df9.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> 
fmha_ck_autogen_a8a4af070ee46d802cb11086b93daf91538f8a04.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_79f182ae021e23869d7bebf2a9b4575bdc910ed0.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_770ad1eb1b30ad8f1e7c17df486093129b2d5630.hip diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/rename_ck_autogen_files.sh b/aten/src/ATen/native/transformers/hip/flash_attn/ck/rename_ck_autogen_files.sh new file mode 100644 index 00000000000..0dc441e87ec --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/rename_ck_autogen_files.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -ex + +file_renaming_txt="rename_ck_autogen_files.output.txt" +rm -rf $file_renaming_txt +for file in `ls fmha_*wd*hip`; do + sha1=$(sha1sum $file | cut -d' ' -f1) + new_file="fmha_ck_autogen_${sha1}.hip" + mv $file $new_file + echo "$file -> $new_file" >> $file_renaming_txt +done diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/rotary.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/rotary.hpp new file mode 100644 index 00000000000..85754c03787 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/rotary.hpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// keep sync with RotaryEmbeddingEnum +enum class rope_enum +{ + none = 0, + interleaved = 1, + half_rotated = 2, +}; + +template +std::tuple, ck_tile::HostTensor> +generate_rotary_cos_sin(ck_tile::index_t seqlen, + ck_tile::index_t rotary_dim, + std::optional seed = std::nullopt) +{ + // return dummy tensors if we won't apply RoPE at all + if(rotary_dim <= 0) + { + ck_tile::HostTensor dummy({1, 1}); + return std::make_tuple(dummy, dummy); + } + + std::mt19937 random_engine(seed.has_value() ? 
*seed : std::random_device{}()); + std::uniform_real_distribution generator(0.0f, 1.0f); + + const ck_tile::index_t num_rows = seqlen * 2; + const ck_tile::index_t num_cols = rotary_dim / 2; + + using std::begin, std::end; + + ck_tile::HostTensor angle({num_rows, num_cols}); + std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; }); + + ck_tile::HostTensor cos({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) { + return ck_tile::type_convert(std::cos(origin_value)); + }); + + ck_tile::HostTensor sin({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) { + return ck_tile::type_convert(std::sin(origin_value)); + }); + + return std::make_tuple(cos, sin); +} + +template +std::tuple, ck_tile::HostTensor> +slice_rotary_cos_sin(const ck_tile::HostTensor& cos, + const ck_tile::HostTensor& sin, + ck_tile::index_t seqlen_offset, + ck_tile::index_t seqlen) +{ + assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2); + assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1)); + + assert(static_cast(seqlen_offset + seqlen) <= cos.get_length(0)); + + const ck_tile::index_t num_rows = seqlen; + const ck_tile::index_t num_cols = cos.get_length(1); + + ck_tile::HostTensor cos_pt({num_rows, num_cols}); + cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); }); + + ck_tile::HostTensor sin_pt({num_rows, num_cols}); + sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); }); + + return std::make_tuple(cos_pt, sin_pt); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h new file mode 100644 index 00000000000..9d4252ad6ed --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -0,0 +1,503 @@ +#pragma once +#include + +#include +#include +#include + + +namespace pytorch_flash { + +// AOTriton Implementation +TORCH_API +std::tuple +mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_); + +std::tuple +mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
+ std::optional &block_table_, + std::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_); + +std::tuple +mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +std::tuple +mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); +#if defined(USE_CK_FLASH_ATTENTION) +// CK implementation +TORCH_API +std::tuple +mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_); + +std::tuple +mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &out_, // 
total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + std::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_); + +std::tuple +mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +std::tuple +mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); +#endif + +TORCH_API +inline std::tuple +mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_CK_FLASH_ATTENTION) + if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { + return mha_fwd_ck(q, + k, + v, + out_, 
+ alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); + } else { + return mha_fwd_aot(q, + k, + v, + out_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); + + } +#else + return mha_fwd_aot(q, + k, + v, + out_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +#endif +} + +inline std::tuple +mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + std::optional &block_table_, // Not used on ROCm. Keeping for parity with CUDA + std::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_CK_FLASH_ATTENTION) + if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { + return mha_varlen_fwd_ck(q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); + } else { + return mha_varlen_fwd_aot(q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + block_table_, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); + } +#else + return mha_varlen_fwd_aot(q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + block_table_, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +#endif + +} + + +inline std::tuple +mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) { +#if defined(USE_CK_FLASH_ATTENTION) + if(at::globalContext().getROCmFAPreferredBackend() == 
at::ROCmFABackend::Ck) { + return mha_bwd_ck(dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); + } else { + return mha_bwd_aot(dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); + } +#else + return mha_bwd_aot(dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); +#endif + +} + +inline std::tuple +mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) { +#if defined(USE_CK_FLASH_ATTENTION) + if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { + return mha_varlen_bwd_ck(dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); + } else { + return mha_varlen_bwd_aot(dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); + } +#else + return mha_varlen_bwd_aot(dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); +#endif +} + +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp new file mode 100644 index 00000000000..3ad4766f6e1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp @@ -0,0 +1,53 @@ 
+/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#ifdef OLD_GENERATOR_PATH +#include +#else +#include +#endif + +#include + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +namespace flash { +inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* rng_state) +{ + // Imitate from PyTorch + // https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17 + if (arg.captured_) { + rng_state[0] = static_cast(*arg.seed_.ptr); + rng_state[1] = static_cast(*(arg.offset_.ptr) + arg.offset_intragraph_); + } else { + rng_state[0] = arg.seed_.val; + rng_state[1] = arg.offset_.val; + } +} + + +} // namespace flash diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 3825b80bd84..2ecfa5a8197 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1380,6 +1380,9 @@ if(USE_ROCM) if(USE_MEM_EFF_ATTENTION) target_compile_definitions(torch_hip PRIVATE USE_MEM_EFF_ATTENTION) endif() + if(USE_CK_FLASH_ATTENTION) + target_compile_definitions(torch_hip PRIVATE USE_CK_FLASH_ATTENTION) + endif() endif() if(BUILD_LITE_INTERPRETER) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 9355e01aad9..b46560e123b 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -130,6 +130,7 @@ function(caffe2_print_configuration_summary) if(${USE_ROCM}) message(STATUS " ROCM_VERSION : ${ROCM_VERSION}") message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}") + message(STATUS " USE_CK_FLASH_ATTENTION : ${USE_CK_FLASH_ATTENTION}") message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}") endif() message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}") diff --git a/docs/source/backends.rst b/docs/source/backends.rst index 6d3500c8542..de11a3c9574 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -69,6 +69,8 @@ torch.backends.cuda .. autofunction:: torch.backends.cuda.preferred_blas_library +.. autofunction:: torch.backends.cuda.preferred_rocm_fa_library + .. autofunction:: torch.backends.cuda.preferred_linalg_library .. 
autoclass:: torch.backends.cuda.SDPAParams diff --git a/test/test_transformers.py b/test/test_transformers.py index e291a6c7956..715fbe4297b 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -3317,6 +3317,10 @@ class TestSDPACudaOnly(NNTestCase): if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30: unittest.skip("Reference implementation OOM") return + if TEST_WITH_ROCM and dropout_p != 0: + self.skipTest("CK does not support tensor format dropout masks") + if TEST_WITH_ROCM and head_dim > 128: + self.skipTest("CK does not support head dims over 128") scale = scale if scale is None else (1 / head_dim) num_heads_q = num_heads_kv = 4 diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 26af7eec1fb..f2ae6200d2c 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -101,7 +101,6 @@ includes = [ "aten/src/ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h", "aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h", "aten/src/ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h", - "aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h", "aten/src/THC/*", "aten/src/ATen/test/*", # CMakeLists.txt isn't processed by default, but there are a few diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 2d7d5bd50e6..31a00510c63 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1288,6 +1288,14 @@ class _BlasBackend: Cublaslt: _BlasBackend Ck: _BlasBackend +def _get_rocm_fa_preferred_backend() -> torch._C._ROCmFABackend: ... +def _set_rocm_fa_preferred_backend(arg: torch._C._ROCmFABackend): ... + +class _ROCmFABackend: + Default: _ROCmFABackend + AOTriton: _ROCmFABackend + Ck: _ROCmFABackend + class ConvBackend(Enum): ... 
class Tag(Enum): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index d5b992add75..53defbd20fa 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -616,6 +616,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys( "torch._C._get_function_stack_at", "torch._C._get_graph_executor_optimize", "torch._C._get_linalg_preferred_backend", + "torch._C._get_rocm_fa_preferred_backend", "torch._C._get_math_sdp_enabled", "torch._C._get_math_sdp_allow_fp16_bf16_reduction", "torch._C._get_max_operator_version", @@ -1144,6 +1145,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys( "torch._C._set_grad_enabled", "torch._C._set_graph_executor_optimize", "torch._C._set_linalg_preferred_backend", + "torch._C._set_rocm_fa_preferred_backend", "torch._C._set_meta_in_tls_dispatch_include", "torch._C._set_mkldnn_enabled", "torch._C._set_multithreading_enabled", @@ -2424,6 +2426,7 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys( "torch.backends.cuda.enable_cudnn_sdp", "torch.backends.cuda.preferred_blas_library", "torch.backends.cuda.preferred_linalg_library", + "torch.backends.cuda.preferred_rocm_fa_library", "torch.backends.cuda.sdp_kernel", "torch.backends.cudnn._init", "torch.backends.cudnn.flags", diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index 2b7aa449466..b305819c1b0 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -14,6 +14,7 @@ __all__ = [ "cuBLASModule", "preferred_linalg_library", "preferred_blas_library", + "preferred_rocm_fa_library", "cufft_plan_cache", "matmul", "SDPAParams", @@ -264,9 +265,57 @@ def preferred_blas_library( return torch._C._get_blas_preferred_backend() +_ROCmFABackends = { + "default": torch._C._ROCmFABackend.Default, + "aotriton": torch._C._ROCmFABackend.AOTriton, + "ck": torch._C._ROCmFABackend.Ck, +} +_ROCmFABackends_str = ", ".join(_ROCmFABackends.keys()) + + from torch._C import _SDPAParams as SDPAParams, _SDPBackend as SDPBackend +def preferred_rocm_fa_library( + backend: Union[None, str, torch._C._ROCmFABackend] = None +) -> torch._C._ROCmFABackend: + r""" + [ROCm-only] + Override the backend PyTorch uses in ROCm environments for Flash Attention. Choose between AOTriton and CK. + + .. warning:: This flag is experimental and subject to change. + + When Flash Attention is enabled and desired, PyTorch defaults to using AOTriton as the backend. + This flag (a :class:`str`) allows users to override this backend to use composable_kernel. + + * If `"default"` is set then the default backend will be used wherever possible. Currently AOTriton. + * If `"aotriton"` is set then AOTriton will be used wherever possible. + * If `"ck"` is set then CK will be used wherever possible. + * When no input is given, this function returns the currently preferred library. + * Users may use the environment variable TORCH_ROCM_FA_PREFER_CK=1 to set the preferred library to CK + globally. + + Note: When a library is preferred, other libraries may still be used if the preferred library + doesn't implement the operation(s) called. + This flag may achieve better performance if PyTorch's library selection is incorrect + for your application's inputs. + """ + if backend is None: + pass + elif isinstance(backend, str): + if backend not in _ROCmFABackends: + raise RuntimeError( + "Unknown input value. " f"Choose from: {_ROCmFABackends_str}."
+ ) + torch._C._set_rocm_fa_preferred_backend(_ROCmFABackends[backend]) + elif isinstance(backend, torch._C._ROCmFABackend): + torch._C._set_rocm_fa_preferred_backend(backend) + else: + raise ValueError("Unknown input value. " f"Choose from: {_ROCmFABackends_str}.") + + return torch._C._get_rocm_fa_preferred_backend() + + # Set the __module__ attribute SDPAParams.__module__ = "torch.backends.cuda" SDPAParams.__name__ = "SDPAParams" diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 044199db29b..2230b15aeb3 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -15,6 +15,7 @@ #include #include #include + #include #include #include @@ -108,6 +109,7 @@ #include #ifdef USE_CUDA +#include #include #include #ifdef __HIP_PLATFORM_AMD__ @@ -2192,6 +2194,18 @@ Call this whenever a new thread is created in order to propagate values from return at::globalContext().blasPreferredBackend(); }); + py::enum_(py_module, "_ROCmFABackend") + .value("Default", at::ROCmFABackend::Default) + .value("AOTriton", at::ROCmFABackend::AOTriton) + .value("Ck", at::ROCmFABackend::Ck); + + py_module.def("_set_rocm_fa_preferred_backend", [](at::ROCmFABackend b) { + at::globalContext().setROCmFAPreferredBackend(b); + }); + py_module.def("_get_rocm_fa_preferred_backend", []() { + return at::globalContext().getROCmFAPreferredBackend(); + }); + py_module.def( "_construct_storage_from_data_pointer", [](int64_t data_ptr, c10::Device device, size_t size_bytes) {